Spaces:
Running
on
Zero
Running
on
Zero
import torch | |
from torch.nn.functional import softmax | |
from transformers import AutoTokenizer, AutoModelForCausalLM, AutoModel | |
import os | |
from time import time | |
from logger import logger | |
from time import sleep | |
# MOCK_MODEL_CALL="true" skips loading the tokenizer and model entirely.
mock_model_call = os.getenv('MOCK_MODEL_CALL') == 'true'
if not mock_model_call:
    # USE_CONDA="true" loads the model with device_map=device; otherwise
    # device_map=None is passed (transformers' default placement).
    use_conda = os.getenv('USE_CONDA', "false") == "true"
    device = "cuda"
    # Local path or hub id of the model,
    # e.g. "granite-guardian-3b-pipecleaner-r241024a".
    model_path = os.getenv('MODEL_PATH')
    logger.info(f'Model path is "{model_path}"')
    if not model_path:
        # Fail fast with an actionable message instead of an opaque error
        # raised deep inside from_pretrained(None) / from_pretrained('').
        raise RuntimeError(
            'MODEL_PATH environment variable must be set '
            'when MOCK_MODEL_CALL is not "true"'
        )
    tokenizer = AutoTokenizer.from_pretrained(model_path)
    model = AutoModelForCausalLM.from_pretrained(
        model_path,
        device_map=device if use_conda else None,
    )
def generate_text(prompt):
    """Ask the model for a Yes/No assessment of a single chat message.

    Args:
        prompt: One chat message dict. It is wrapped as a one-element
            conversation for ``tokenizer.apply_chat_template``, and its
            ``"content"`` key is logged. (Presumably ``{"role": ...,
            "content": ...}`` -- confirm against the caller.)

    Returns:
        dict with keys:
          * ``'assessment'`` -- ``'Yes'`` if P(Yes) > 0.5 at the first
            generated position, else ``'No'``.
          * ``'certainty'`` -- probability of the chosen answer, rounded to
            3 decimals and returned as a ``str``.

    In mock mode (module-level ``mock_model_call``), it sleeps 3 seconds and
    returns a canned result with the same keys and value types.
    """
    logger.debug('Starting evaluation...')
    logger.debug(f'Prompts content is: \n{prompt["content"]}')

    # Use the module-level flag: it is the same decision that controlled
    # whether `tokenizer`/`model` were loaded at import time. Re-reading the
    # environment here could pick the real path with no model loaded.
    if mock_model_call:
        logger.debug('Returning mocked model result.')
        sleep(3)
        # certainty is a string to match the real path below.
        return {'assessment': 'Yes', 'certainty': '0.97'}

    start = time()
    tokenized_chat = tokenizer.apply_chat_template(
        [prompt],
        tokenize=True,
        add_generation_prompt=True,
        return_tensors="pt")
    if use_conda:
        # The model was loaded with device_map=device in this mode.
        tokenized_chat = tokenized_chat.to(device)

    with torch.no_grad():
        logits = model(tokenized_chat).logits
        # NOTE(review): this generation is only used for the debug log below
        # and costs up to 128 extra decoding steps per call.
        gen_outputs = model.generate(tokenized_chat, max_new_tokens=128)
    generated_text = tokenizer.decode(gen_outputs[0])
    logger.debug(f'Model generated text: \n{generated_text}')

    # Compare only the 'No' and 'Yes' tokens at the last prompt position,
    # i.e. the model's choice for the first generated token.
    vocab = tokenizer.get_vocab()
    selected_logits = logits[0, -1, [vocab['No'], vocab['Yes']]]
    probabilities = softmax(selected_logits, dim=0)
    prob_yes = probabilities[1].item()
    logger.debug(f'Certainty is: {prob_yes} from probabilities {probabilities}')

    assessment = 'Yes' if prob_yes > 0.5 else 'No'
    # Report confidence in the chosen answer: the larger of P(Yes), 1 - P(Yes).
    certainty = prob_yes if prob_yes >= 0.5 else 1 - prob_yes
    certainty = f'{round(certainty,3)}'

    total = time() - start
    logger.debug(f'it took {round(total/60, 2)} mins')
    return {'assessment': assessment, 'certainty': certainty}