Spaces: Running on Zero
Martín Santillán Cooper committed · Commit fb6a6b8 · Parent(s): 2b6005c

Add logs
model.py CHANGED
@@ -146,12 +146,7 @@ def parse_output_watsonx(generated_tokens_list):
 @spaces.GPU
 def generate_text(messages, criteria_name):
     logger.debug(f"Messages used to create the prompt are: \n{messages}")
-
     start = time()
-
-    chat = get_prompt(messages, criteria_name)
-    logger.debug(f"Prompt is \n{chat}")
-
     if inference_engine == "MOCK":
         logger.debug("Returning mocked model result.")
         sleep(1)
@@ -159,12 +154,15 @@ def generate_text(messages, criteria_name):
 
     elif inference_engine == "WATSONX":
         chat = get_prompt(messages, criteria_name)
+        logger.debug(f"Prompt is \n{chat}")
         generated_tokens = generate_tokens(chat)
         label, prob_of_risk = parse_output_watsonx(generated_tokens)
 
     elif inference_engine == "VLLM":
         input_ids = tokenizer.apply_chat_template(messages, add_generation_prompt=True, return_tensors="pt").to(model.device)
+        logger.debug(f"input_ids are: {input_ids}")
         input_len = input_ids.shape[1]
+        logger.debug(f"input_len are: {input_len}")
 
         with torch.no_grad():
             # output = model.generate(chat, sampling_params, use_tqdm=False)
@@ -174,8 +172,11 @@ def generate_text(messages, criteria_name):
                 max_new_tokens=nlogprobs,
                 return_dict_in_generate=True,
                 output_scores=True,)
+            logger.debug(f"model output is are: {output}")
 
         label, prob_of_risk = parse_output(output, input_len)
+        logger.debug(f"label is are: {label}")
+        logger.debug(f"prob_of_risk is are: {prob_of_risk}")
     else:
         raise Exception("Environment variable 'INFERENCE_ENGINE' must be one of [WATSONX, MOCK, VLLM]")
 
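The commit only adds logger.debug calls, and how the module-level logger is configured is not part of this diff. As a minimal sketch (assuming model.py uses Python's standard logging module, which the logger.debug pattern suggests), the new messages only reach the Space's container logs if the logger is set to the DEBUG level:

    import logging

    # Hypothetical setup; the real model.py may configure its logger differently.
    logging.basicConfig(
        level=logging.DEBUG,  # below DEBUG, the new logger.debug(...) lines are dropped silently
        format="%(asctime)s %(levelname)s %(name)s: %(message)s",
    )
    logger = logging.getLogger(__name__)

    logger.debug("Prompt is \n<chat>")  # same call pattern as the lines added in this commit

In the VLLM branch, model.generate is called with return_dict_in_generate=True and output_scores=True, so the output value logged above bundles the generated token ids (output.sequences) with one logits tensor per generated step (output.scores). parse_output(output, input_len) itself is not shown in this diff; purely as an illustration of what it has to work with (the function name, the tokenizer argument, and the "Yes"/"No" label tokens below are assumptions, not taken from the repo), a parser along these lines could produce a label and prob_of_risk:

    import torch

    def parse_output_sketch(output, input_len, tokenizer):
        # output.sequences holds prompt + generated ids; input_len strips the prompt part.
        generated_text = tokenizer.decode(output.sequences[0, input_len:], skip_special_tokens=True)
        # output.scores is a tuple with one [batch, vocab] logits tensor per generated token.
        first_step_probs = torch.softmax(output.scores[0][0].float(), dim=-1)
        yes_id = tokenizer.convert_tokens_to_ids("Yes")  # assumed label tokens; the deployed
        no_id = tokenizer.convert_tokens_to_ids("No")    # model may use different ones
        prob_of_risk = (first_step_probs[yes_id] /
                        (first_step_probs[yes_id] + first_step_probs[no_id])).item()
        label = "Yes" if "Yes" in generated_text else "No"
        return label, prob_of_risk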