Spaces:
Running
on
Zero
Running
on
Zero
File size: 2,251 Bytes
d46878a 912f740 d46878a 912f740 d46878a 912f740 d46878a 912f740 d46878a 912f740 d46878a 912f740 d46878a 912f740 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 |
import torch
from torch.nn.functional import softmax
from transformers import AutoTokenizer, AutoModelForCausalLM, AutoModel
import os
from time import time
from logger import logger
from time import sleep
# Load the tokenizer and model once at import time. When MOCK_MODEL_CALL is
# "true" at import, nothing is loaded, and tokenizer/model/use_conda/device are
# left undefined.
mock_model_call = os.getenv('MOCK_MODEL_CALL') == 'true'
if not mock_model_call:
    # When USE_CONDA is "true", the model is loaded with device_map="cuda".
    # Otherwise device_map=None is passed and placement is left to transformers.
    use_conda = os.getenv('USE_CONDA', "false") == "true"
    device = "cuda"
    # Local path or hub id of the model, e.g.
    # "granite-guardian-3b-pipecleaner-r241024a".
    model_path = os.getenv('MODEL_PATH')
    logger.info(f'Model path is "{model_path}"')
    if not model_path:
        # Fail fast with an actionable message instead of the opaque error
        # that from_pretrained(None) would otherwise raise.
        raise RuntimeError(
            'MODEL_PATH environment variable must be set '
            'when MOCK_MODEL_CALL is not "true"'
        )
    tokenizer = AutoTokenizer.from_pretrained(model_path)
    model = AutoModelForCausalLM.from_pretrained(
        model_path,
        device_map=device if use_conda else None,
    )
def generate_text(prompt):
    """Classify a single chat message as 'Yes' or 'No' using the loaded model.

    Mock mode is re-checked on every call by reading MOCK_MODEL_CALL. The real
    path uses the module-level ``tokenizer``, ``model``, ``use_conda`` and
    ``device``. Those names exist only if MOCK_MODEL_CALL was not "true" when
    the module was imported.

    Args:
        prompt: One chat message. It is passed to
            ``tokenizer.apply_chat_template`` as ``[prompt]`` and must support
            ``prompt["content"]``. It is presumably a
            ``{"role": ..., "content": ...}`` dict; confirm against callers.

    Returns:
        dict with ``'assessment'`` ('Yes' or 'No') and ``'certainty'``. The
        certainty is the probability of the chosen answer, rounded to 3
        decimals and formatted as a string (e.g. '0.97'). The mock path
        returns the same types.

    Raises:
        KeyError: if 'Yes' or 'No' is not an entry in the tokenizer vocabulary.
    """
    logger.debug('Starting evaluation...')
    logger.debug(f'Prompts content is: \n{prompt["content"]}')
    if os.getenv('MOCK_MODEL_CALL') == 'true':
        logger.debug('Returning mocked model result.')
        sleep(3)  # presumably mimics real model latency
        # Certainty is a string here too, matching the real path below.
        return {'assessment': 'Yes', 'certainty': '0.97'}

    start = time()
    tokenized_chat = tokenizer.apply_chat_template(
        [prompt],
        tokenize=True,
        add_generation_prompt=True,
        return_tensors="pt")
    if use_conda:
        # In this mode the model was loaded with device_map=device, so the
        # inputs must be moved to the same device.
        tokenized_chat = tokenized_chat.to(device)

    with torch.no_grad():
        logits = model(tokenized_chat).logits
        # NOTE(review): the generated text is only used for the debug log
        # below, yet this runs up to 128 extra decoding steps on every call.
        gen_outputs = model.generate(tokenized_chat, max_new_tokens=128)

    generated_text = tokenizer.decode(gen_outputs[0])
    logger.debug(f'Model generated text: \n{generated_text}')

    # Take the logits of the 'No' and 'Yes' tokens at the last position of the
    # first sequence. Softmax over just those two gives P(Yes) relative to No.
    vocab = tokenizer.get_vocab()
    selected_logits = logits[0, -1, [vocab['No'], vocab['Yes']]]
    probabilities = softmax(selected_logits, dim=0)
    prob_yes = probabilities[1].item()
    logger.debug(f'Certainty is: {prob_yes} from probabilities {probabilities}')

    assessment = 'Yes' if prob_yes > 0.5 else 'No'
    # Report the probability of whichever answer was chosen
    # (an exact 0.5 yields 'No' with certainty 0.5).
    certainty = prob_yes if prob_yes >= 0.5 else 1 - prob_yes

    total = time() - start
    logger.debug(f'it took {round(total/60, 2)} mins')
    return {'assessment': assessment, 'certainty': f'{round(certainty,3)}'}