File size: 6,710 Bytes
5269ad1
f97dae7
5b7f169
 
 
 
f97dae7
 
026d799
5b7f169
 
182a21a
bce909e
 
 
5269ad1
f97dae7
2b6005c
5269ad1
5b7f169
f97dae7
 
5b7f169
a50a656
5b7f169
 
33193a0
2e81d77
bce909e
026d799
2b6005c
 
2cb730a
f97dae7
5b7f169
 
 
 
 
f97dae7
 
 
5b7f169
 
 
 
2b6005c
 
5269ad1
2b6005c
 
 
 
 
5269ad1
 
2b6005c
f97dae7
 
5269ad1
 
 
 
 
 
 
 
 
 
 
2b6005c
 
5269ad1
2b6005c
f97dae7
2b6005c
5269ad1
2b6005c
 
 
5269ad1
 
 
5b7f169
2b6005c
 
 
 
 
f97dae7
 
 
 
 
5b7f169
 
 
 
f97dae7
 
 
 
 
5b7f169
2b6005c
5b7f169
2e81d77
2b6005c
5b7f169
 
2e81d77
c786139
f97dae7
 
 
 
5b7f169
 
f97dae7
5b7f169
 
 
 
 
f97dae7
 
 
 
 
5b7f169
f97dae7
 
 
5b7f169
f97dae7
 
 
 
 
 
 
 
 
2e81d77
5b7f169
477d968
2e81d77
5b7f169
2e81d77
5b7f169
 
f97dae7
5b7f169
 
 
2b6005c
fb6a6b8
f97dae7
 
5269ad1
5b7f169
2b6005c
fb6a6b8
2b6005c
fb6a6b8
2b6005c
f97dae7
bce909e
2b6005c
 
 
 
 
 
fb6a6b8
2b6005c
 
fb6a6b8
 
f97dae7
 
d46878a
5b7f169
 
 
20a9c66
 
5b7f169
d46878a
5b7f169
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
import math
import os
from time import sleep, time

import spaces
import torch
from ibm_watsonx_ai.client import APIClient
from ibm_watsonx_ai.foundation_models import ModelInference
from transformers import AutoModelForCausalLM, AutoTokenizer

from logger import logger

# from vllm import LLM, SamplingParams


# Verdict tokens: "Yes" is mapped to a risky label, "No" to a safe label
# (see parse_output / parse_output_watsonx).
safe_token = "No"
risky_token = "Yes"
# Number of top candidate tokens inspected per step when estimating the risk
# probability; the local path also uses it as ``max_new_tokens``.
nlogprobs = 20

# Backend selector: "VLLM" (local model), "WATSONX" (IBM watsonx.ai) or "MOCK"
# (canned result, handled in generate_text).
inference_engine = os.getenv("INFERENCE_ENGINE", "VLLM")
logger.debug(f"Inference engine is: '{inference_engine}'")

if inference_engine == "VLLM":
    # NOTE: despite the engine name, this path loads the model with Hugging Face
    # transformers (AutoModelForCausalLM) on a CUDA device, not with vLLM.
    device = torch.device("cuda")

    model_path = os.getenv("MODEL_PATH", "ibm-granite/granite-guardian-3.0-8b")
    logger.debug(f"model_path is {model_path}")
    tokenizer = AutoTokenizer.from_pretrained(model_path)
    model = AutoModelForCausalLM.from_pretrained(model_path)
    model = model.to(device).eval()

elif inference_engine == "WATSONX":
    client = APIClient(
        credentials={
            "api_key": os.getenv("WATSONX_API_KEY"),
            # Endpoint defaults to us-south; WATSONX_URL allows other regions.
            "url": os.getenv("WATSONX_URL", "https://us-south.ml.cloud.ibm.com"),
        }
    )

    client.set.default_project(os.getenv("WATSONX_PROJECT_ID"))
    # The prompt is rendered locally with the chat template (get_prompt), so
    # the matching Hugging Face tokenizer is still required on this path.
    hf_model_path = "ibm-granite/granite-guardian-3.0-8b"
    tokenizer = AutoTokenizer.from_pretrained(hf_model_path)

    model_id = "ibm/granite-guardian-3-8b"
    model = ModelInference(model_id=model_id, api_client=client)


def parse_output(output, input_len):
    """Turn a transformers ``generate`` result into ``(label, prob_of_risk)``.

    Args:
        output: Result of ``model.generate(..., return_dict_in_generate=True,
            output_scores=True)``; its ``scores`` and ``sequences`` are read.
        input_len: Number of prompt tokens; the generated answer is
            ``output.sequences[:, input_len:]`` (row 0 is decoded).

    Returns:
        ``(label, prob_of_risk)`` where ``label`` is ``risky_token``,
        ``safe_token`` or ``"Failed"`` when the decoded answer matches neither
        (case-insensitive, stripped), and ``prob_of_risk`` is a Python float,
        or ``None`` when ``nlogprobs`` is 0 and no probability is computed.
    """
    prob_of_risk = None
    if nlogprobs > 0:
        # Top-k candidates for every generation step except the last one
        # (presumably the trailing end-of-sequence step — confirm).
        # NOTE(review): transformers documents ``scores`` as processed logits,
        # which get_probablities treats as log-probabilities — verify intent.
        list_index_logprobs_i = [
            torch.topk(token_i, k=nlogprobs, largest=True, sorted=True)
            for token_i in list(output.scores)[:-1]
        ]
        prob = get_probablities(list_index_logprobs_i)
        prob_of_risk = prob[1]

    res = tokenizer.decode(output.sequences[:, input_len:][0], skip_special_tokens=True).strip()
    if risky_token.lower() == res.lower():
        label = risky_token
    elif safe_token.lower() == res.lower():
        label = safe_token
    else:
        label = "Failed"

    # prob is a tensor, so convert the 0-d element to a float; there is nothing
    # to convert when probabilities were not requested (nlogprobs == 0).
    return label, prob_of_risk.item() if prob_of_risk is not None else None

def get_probablities(logprobs):
    """Estimate normalised ``[p_safe, p_risky]`` from per-step top-k results.

    ``logprobs`` is a sequence of ``torch.topk`` results, one per generation
    step; only batch row 0 of each is inspected. Each candidate whose token
    string (stripped, lower-cased) equals ``safe_token`` or ``risky_token``
    adds ``exp(score)`` to that class's mass. Both masses start at a 1e-50
    floor so their logarithms stay finite.

    Returns:
        A two-element tensor ``[p_safe, p_risky]`` (softmax of the log-masses).
    """
    safe_key = safe_token.lower()
    risky_key = risky_token.lower()
    safe_mass = risky_mass = 1e-50

    for step in logprobs:
        step_scores = step.values.tolist()[0]
        step_ids = step.indices.tolist()[0]
        for score, token_id in zip(step_scores, step_ids):
            normalized = tokenizer.convert_ids_to_tokens(token_id).strip().lower()
            if normalized == safe_key:
                safe_mass += math.exp(score)
            if normalized == risky_key:
                risky_mass += math.exp(score)

    log_masses = torch.tensor([math.log(safe_mass), math.log(risky_mass)])
    return torch.softmax(log_masses, dim=0)


def softmax(values):
    """Return the softmax of ``values`` as a list of floats summing to 1.

    The largest value is subtracted before exponentiating, which leaves the
    result mathematically unchanged but keeps ``math.exp`` from overflowing on
    large inputs such as ``[1000, 1000]``.

    Args:
        values: Iterable of real numbers (e.g. log-probabilities).

    Returns:
        A list the same length as ``values``; ``[]`` for empty input.
    """
    values = list(values)
    if not values:
        return []
    peak = max(values)
    # Shift only by a finite peak; subtracting inf/NaN would itself yield NaN,
    # so non-finite inputs fall back to the plain computation.
    shift = peak if math.isfinite(peak) else 0.0
    exp_values = [math.exp(v - shift) for v in values]
    total = sum(exp_values)
    return [v / total for v in exp_values]

def get_probablities_watsonx(top_tokens_list):
    """Estimate normalised ``[p_safe, p_risky]`` from watsonx top tokens.

    ``top_tokens_list`` holds one list of candidates per generated position.
    Each candidate is a dict with ``"text"`` and ``"logprob"`` keys. A candidate
    whose text (stripped, lower-cased) equals ``safe_token`` or ``risky_token``
    adds ``exp(logprob)`` to that class's mass. Both masses start at 1e-50 so
    their logarithms stay finite.

    Returns:
        A two-element list ``[p_safe, p_risky]`` from ``softmax``.
    """
    wanted_safe = safe_token.lower()
    wanted_risky = risky_token.lower()
    safe_mass = 1e-50
    risky_mass = 1e-50

    for candidates in top_tokens_list:
        for candidate in candidates:
            text = candidate["text"].strip().lower()
            if text == wanted_safe:
                safe_mass += math.exp(candidate["logprob"])
            if text == wanted_risky:
                risky_mass += math.exp(candidate["logprob"])

    return softmax([math.log(safe_mass), math.log(risky_mass)])


def get_prompt(messages, criteria_name, return_tensors=None):
    """Render ``messages`` into a Granite Guardian prompt via the chat template.

    The risk criterion reaches the template through ``guardian_config``. The
    name ``"general_harm"`` is sent as ``"harm"``; any other name is sent
    unchanged. A generation prompt is appended.

    Because ``tokenize=False``, the result is untokenised template output.
    ``return_tensors`` is forwarded as-is (presumably without effect when not
    tokenising — verify).
    """
    risk_name = "harm" if criteria_name == "general_harm" else criteria_name
    return tokenizer.apply_chat_template(
        messages,
        guardian_config={"risk_name": risk_name},
        tokenize=False,
        add_generation_prompt=True,
        return_tensors=return_tensors,
    )


@spaces.GPU
def generate_tokens(prompt):
    """Send ``prompt`` to the watsonx.ai model and return its generated tokens.

    Decoding is greedy with at most 20 new tokens. Per-token log-probabilities
    and the top 5 alternatives per position are requested, so that a risk
    probability can be estimated from them afterwards.

    Returns:
        ``generated_tokens`` of the first result in the first response.
    """
    return_options = {
        "token_logprobs": True,
        "generated_tokens": True,
        "input_text": True,
        "top_n_tokens": 5,
    }
    generation_params = {
        "decoding_method": "greedy",
        "max_new_tokens": 20,
        "temperature": 0,
        "return_options": return_options,
    }
    responses = model.generate(prompt=[prompt], params=generation_params)
    first_response = responses[0]
    return first_response["results"][0]["generated_tokens"]


def parse_output_watsonx(generated_tokens_list):
    """Turn watsonx generated tokens into ``(label, prob_of_risk)``.

    Args:
        generated_tokens_list: Token dicts as returned by ``generate_tokens``.
            Each has a ``"text"`` key and, when probabilities are requested,
            a ``"top_tokens"`` key.

    Returns:
        ``(label, prob_of_risk)``.

        ``label`` is ``risky_token`` / ``safe_token`` when the first generated
        token's text matches one (case-insensitive, stripped). Otherwise it is
        ``"Failed"``, which includes the case where nothing was generated.

        ``prob_of_risk`` is a float, or ``None`` when ``nlogprobs`` is 0.
    """
    prob_of_risk = None

    if nlogprobs > 0:
        top_tokens_list = [generated_tokens["top_tokens"] for generated_tokens in generated_tokens_list]
        prob = get_probablities_watsonx(top_tokens_list)
        prob_of_risk = prob[1]

    # The verdict is read from the first generated token only. An empty
    # generation has no verdict and falls through to "Failed" rather than
    # raising StopIteration.
    first_token = next(iter(generated_tokens_list), None)
    res = first_token["text"].strip() if first_token is not None else ""

    if risky_token.lower() == res.lower():
        label = risky_token
    elif safe_token.lower() == res.lower():
        label = safe_token
    else:
        label = "Failed"

    return label, prob_of_risk


@spaces.GPU
def generate_text(messages, criteria_name):
    """Assess ``messages`` against ``criteria_name`` with the configured engine.

    Args:
        messages: Chat messages rendered through the tokenizer's chat template.
        criteria_name: Risk criterion. ``"general_harm"`` is sent to the
            template as ``"harm"``; other names are sent unchanged.

    Returns:
        ``{"assessment": label, "certainty": prob_of_risk}``.

        ``label`` is ``risky_token``, ``safe_token`` or ``"Failed"``; in MOCK
        mode the result is always ``"Yes"`` / ``0.97``.

    Raises:
        Exception: if ``inference_engine`` is not WATSONX, MOCK or VLLM.
    """
    logger.debug(f"Messages used to create the prompt are: \n{messages}")
    start = time()
    if inference_engine == "MOCK":
        logger.debug("Returning mocked model result.")
        sleep(1)
        label, prob_of_risk = "Yes", 0.97

    elif inference_engine == "WATSONX":
        chat = get_prompt(messages, criteria_name)
        logger.debug(f"Prompt is \n{chat}")
        generated_tokens = generate_tokens(chat)
        label, prob_of_risk = parse_output_watsonx(generated_tokens)

    elif inference_engine == "VLLM":
        # Pass the selected criterion to the chat template (same mapping as
        # get_prompt); otherwise criteria_name would be ignored on this path.
        guardian_config = {"risk_name": criteria_name if criteria_name != "general_harm" else "harm"}
        input_ids = tokenizer.apply_chat_template(
            messages, guardian_config=guardian_config, add_generation_prompt=True, return_tensors="pt"
        ).to(model.device)
        logger.debug(f"input_ids are: {input_ids}")
        input_len = input_ids.shape[1]
        logger.debug(f"input_len are: {input_len}")

        with torch.no_grad():
            # Generate from the tokenised prompt; ``chat`` exists only on the
            # watsonx path. max_new_tokens reuses nlogprobs (20).
            output = model.generate(
                input_ids,
                do_sample=False,
                max_new_tokens=nlogprobs,
                return_dict_in_generate=True,
                output_scores=True,
            )
            logger.debug(f"model output is are: {output}")

            label, prob_of_risk = parse_output(output, input_len)
            logger.debug(f"label is are: {label}")
            logger.debug(f"prob_of_risk is are: {prob_of_risk}")
    else:
        raise Exception("Environment variable 'INFERENCE_ENGINE' must be one of [WATSONX, MOCK, VLLM]")

    logger.debug(f"Model generated label: \n{label}")
    logger.debug(f"Model prob_of_risk: \n{prob_of_risk}")

    end = time()
    total = end - start
    logger.debug(f"The evaluation took {total} secs")

    return {"assessment": label, "certainty": prob_of_risk}