import os
from time import sleep, time

import torch
from torch.nn.functional import softmax
from transformers import AutoTokenizer, AutoModelForCausalLM

from logger import logger

# When MOCK_MODEL_CALL is set, skip loading the model entirely and return canned results.
mock_model_call = os.getenv('MOCK_MODEL_CALL') == 'true'
if not mock_model_call:
    # Flag gating GPU placement: when true, the model is mapped onto the CUDA device.
    use_conda = os.getenv('USE_CONDA', 'false') == 'true'
    device = "cuda"
    model_path = os.getenv('MODEL_PATH')  # e.g. "granite-guardian-3b-pipecleaner-r241024a"
    logger.info(f'Model path is "{model_path}"')
    tokenizer = AutoTokenizer.from_pretrained(model_path)
    model = AutoModelForCausalLM.from_pretrained(
        model_path,
        device_map=device if use_conda else None
    )
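
# Example environment configuration (a sketch with illustrative values, not taken
# from the original deployment):
#   export MOCK_MODEL_CALL=false
#   export USE_CONDA=true                                    # place model and inputs on CUDA
#   export MODEL_PATH=/path/to/granite-guardian-3b-pipecleaner-r241024a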

def generate_text(prompt):
    logger.debug('Starting evaluation...')
    logger.debug(f'Prompt content is: \n{prompt["content"]}')
    # Re-read the flag so mocking can be toggled at runtime without reloading the module.
    mock_model_call = os.getenv('MOCK_MODEL_CALL') == 'true'
    if mock_model_call:
        logger.debug('Returning mocked model result.')
        sleep(3)
        return {'assessment': 'Yes', 'certainty': 0.97}
    else:
        start = time()
        tokenized_chat = tokenizer.apply_chat_template(
            [prompt],
            tokenize=True,
            add_generation_prompt=True,
            return_tensors="pt")
        if use_conda:
            # Move the tokenized input to the same CUDA device as the model.
            tokenized_chat = tokenized_chat.to(device)
        with torch.no_grad():
            # One forward pass for the next-token logits used to score 'Yes'/'No',
            # plus a full generation whose decoded text is logged for debugging below.
            logits = model(tokenized_chat).logits
            gen_outputs = model.generate(tokenized_chat, max_new_tokens=128)

        generated_text = tokenizer.decode(gen_outputs[0])
        logger.debug(f'Model generated text: \n{generated_text}')
        # Score the decision by comparing the next-token logits of the 'No' and 'Yes' tokens.
        vocab = tokenizer.get_vocab()
        selected_logits = logits[0, -1, [vocab['No'], vocab['Yes']]]
        probabilities = softmax(selected_logits, dim=0)

        # probabilities[1] is P('Yes') over the two candidate tokens.
        certainty = probabilities[1].item()
        logger.debug(f'Certainty is: {certainty} from probabilities {probabilities}')
        assessment = 'Yes' if certainty > 0.5 else 'No'
        # Report the probability of the chosen answer, so certainty is always >= 0.5,
        # and round it to match the numeric form of the mocked result.
        certainty = round(certainty if certainty >= 0.5 else 1 - certainty, 3)

        end = time()
        total = end - start
        logger.debug(f'Inference took {round(total / 60, 2)} minutes')

        return {'assessment': assessment, 'certainty': certainty}
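

# A minimal usage sketch (not part of the original module): assuming the caller passes
# a single chat-style message dict with 'role' and 'content' keys, generate_text can be
# exercised as below. The prompt text and the __main__ guard are illustrative assumptions.
if __name__ == '__main__':
    example_prompt = {
        'role': 'user',
        'content': 'Is the following statement harmful? "You should always lock your door at night."'
    }
    result = generate_text(example_prompt)
    logger.info(f'Assessment: {result["assessment"]}, certainty: {result["certainty"]}')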