File size: 7,316 Bytes
5269ad1
f97dae7
5b7f169
 
 
 
f97dae7
 
026d799
5b7f169
 
182a21a
bce909e
 
 
5269ad1
f97dae7
2b6005c
5269ad1
5b7f169
f97dae7
 
5b7f169
a50a656
5b7f169
2e41a22
33193a0
2e81d77
bce909e
026d799
2b6005c
 
2cb730a
f97dae7
5b7f169
 
 
 
 
2e41a22
f97dae7
 
2e41a22
5b7f169
 
 
2b6005c
 
5269ad1
2b6005c
2e41a22
 
 
2b6005c
 
5269ad1
 
2e41a22
f97dae7
 
5269ad1
 
 
 
 
 
 
2e41a22
5269ad1
 
 
 
2b6005c
 
5269ad1
2b6005c
f97dae7
2b6005c
5269ad1
2e41a22
5269ad1
 
 
5b7f169
2b6005c
 
 
 
 
2e41a22
f97dae7
 
 
 
 
5b7f169
 
 
 
f97dae7
 
 
 
 
5b7f169
4bd44f6
5b7f169
4bd44f6
 
 
 
 
2e41a22
5b7f169
2e41a22
4bd44f6
5b7f169
2e81d77
c786139
f97dae7
 
 
 
5b7f169
 
f97dae7
5b7f169
 
 
 
 
f97dae7
 
 
 
 
5b7f169
f97dae7
 
 
5b7f169
f97dae7
 
 
 
 
 
 
 
 
2e81d77
5b7f169
477d968
2e81d77
5b7f169
2e81d77
5b7f169
 
f97dae7
5b7f169
 
 
2b6005c
fb6a6b8
f97dae7
 
5269ad1
5b7f169
4bd44f6
2e41a22
 
4bd44f6
 
 
 
2e41a22
4bd44f6
2e41a22
4bd44f6
fb6a6b8
2b6005c
fb6a6b8
2b6005c
f97dae7
bce909e
2b6005c
42e2cdc
2b6005c
 
 
2e41a22
 
51c0b7a
2b6005c
 
fb6a6b8
 
f97dae7
 
d46878a
5b7f169
 
 
20a9c66
 
5b7f169
d46878a
5b7f169
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
import math
import os
from time import sleep, time

import spaces
import torch
from ibm_watsonx_ai.client import APIClient
from ibm_watsonx_ai.foundation_models import ModelInference
from transformers import AutoModelForCausalLM, AutoTokenizer

from logger import logger

# from vllm import LLM, SamplingParams


# Label strings the Guardian model is expected to answer with: safe_token ("No")
# means the input was judged safe, risky_token ("Yes") means it was judged risky.
safe_token = "No"
risky_token = "Yes"
# Number of top-k candidates inspected per generated step in parse_output(). The
# transformers branch of generate_text() also reuses it as max_new_tokens.
nlogprobs = 20

# Backend selector, read once at import time. generate_text() accepts "VLLM",
# "WATSONX" or "MOCK". Any other value loads nothing below, which leaves `tokenizer`
# and `model` undefined; generate_text() then raises for unknown engines.
inference_engine = os.getenv("INFERENCE_ENGINE", "VLLM")
logger.debug(f"Inference engine is: '{inference_engine}'")

if inference_engine == "VLLM":
    # NOTE(review): despite the name, this branch loads the model with Hugging Face
    # transformers (the vLLM code is commented out) and requires a CUDA device.
    device = torch.device("cuda")

    model_path = os.getenv("MODEL_PATH", "ibm-granite/granite-guardian-3.0-2b")
    logger.debug(f"model_path is {model_path}")
    tokenizer = AutoTokenizer.from_pretrained(model_path)
    # sampling_params = SamplingParams(temperature=0.0, logprobs=nlogprobs)
    # model = LLM(model=model_path, tensor_parallel_size=1)
    model = AutoModelForCausalLM.from_pretrained(model_path)
    # Move the model to the GPU and switch to inference mode (disables dropout etc.).
    model = model.to(device).eval()

elif inference_engine == "WATSONX":
    # Remote inference on IBM watsonx.ai. The API key and project id come from the
    # environment; the service URL is fixed to the us-south region.
    client = APIClient(
        credentials={"api_key": os.getenv("WATSONX_API_KEY"), "url": "https://us-south.ml.cloud.ibm.com"}
    )

    client.set.default_project(os.getenv("WATSONX_PROJECT_ID"))
    # The tokenizer is still loaded locally from Hugging Face so prompts can be
    # rendered with the model's chat template before being sent to watsonx.
    hf_model_path = "ibm-granite/granite-guardian-3.0-2b"
    tokenizer = AutoTokenizer.from_pretrained(hf_model_path)

    model_id = "ibm/granite-guardian-3-2b"  # watsonx.ai identifier of the 2b Guardian model
    model = ModelInference(model_id=model_id, api_client=client)


def parse_output(output, input_len):
    """Turn a Hugging Face ``generate`` result into a ``(label, prob_of_risk)`` pair.

    Args:
        output: The object returned by ``model.generate(..., return_dict_in_generate=True,
            output_scores=True)``. Its ``scores`` (one tensor per generated step) and
            ``sequences`` attributes are read.
        input_len: Number of prompt tokens. It is used to slice the generated
            continuation out of ``output.sequences``.

    Returns:
        A ``(label, prob_of_risk)`` tuple.

        ``label`` is ``risky_token`` or ``safe_token`` when the decoded continuation
        equals one of them (case-insensitively). Otherwise it is ``"Failed"``.

        ``prob_of_risk`` is a Python float in [0, 1], or ``None`` when ``nlogprobs``
        is not positive. Previously that case crashed on ``None.item()``.
    """
    label, prob_of_risk = None, None
    if nlogprobs > 0:
        # `output.scores` holds processed logits *before* softmax. Normalise each step
        # with log_softmax so the top-k values are genuine log-probabilities, which is
        # what get_probablities() exponentiates and sums across steps. The last step is
        # skipped, as before (presumably the end-of-sequence token; TODO confirm).
        top_logprobs_per_step = [
            torch.topk(torch.log_softmax(step_scores.float(), dim=-1), k=nlogprobs, largest=True, sorted=True)
            for step_scores in list(output.scores)[:-1]
        ]
        prob = get_probablities(top_logprobs_per_step)
        # prob is a 2-element tensor [p_safe, p_risky]; return a plain float.
        prob_of_risk = prob[1].item()

    # Decode only the generated continuation (first batch row) and compare it to the labels.
    res = tokenizer.decode(output.sequences[:, input_len:][0], skip_special_tokens=True).strip()
    if risky_token.lower() == res.lower():
        label = risky_token
    elif safe_token.lower() == res.lower():
        label = safe_token
    else:
        label = "Failed"

    return label, prob_of_risk


def get_probablities(logprobs):
    """Collapse per-step top-k log-probabilities into normalised [p_safe, p_risky].

    Args:
        logprobs: Iterable of ``torch.topk`` results (objects with ``values`` and
            ``indices`` tensors). Only row 0 of each is read, presumably the first
            batch element.

    Returns:
        A 2-element tensor ``[p_safe, p_risky]`` that sums to 1. Each entry is the
        exponentiated mass of every candidate whose token text, stripped and
        lower-cased, equals ``safe_token`` / ``risky_token``, summed over all steps.
        Both masses start at 1e-50 so the ``math.log`` calls never see zero.
    """
    safe_target = safe_token.lower()
    risky_target = risky_token.lower()
    safe_mass = 1e-50
    risky_mass = 1e-50

    for step in logprobs:
        step_values = step.values.tolist()[0]
        step_ids = step.indices.tolist()[0]
        for value, token_id in zip(step_values, step_ids):
            normalized = tokenizer.convert_ids_to_tokens(token_id).strip().lower()
            if normalized == safe_target:
                safe_mass += math.exp(value)
            if normalized == risky_target:
                risky_mass += math.exp(value)

    # softmax over log-masses is just mass / (safe_mass + risky_mass).
    log_masses = torch.tensor([math.log(safe_mass), math.log(risky_mass)])
    return torch.softmax(log_masses, dim=0)


def softmax(values):
    """Return the softmax of a sequence of real numbers as a list of floats.

    The maximum is subtracted before exponentiating. This leaves the result
    mathematically unchanged but keeps ``math.exp`` from overflowing on large
    inputs; the naive form raises OverflowError once a value exceeds ~709.

    Args:
        values: Finite real numbers (any iterable).

    Returns:
        A list of the same length whose entries are non-negative and sum to 1,
        or an empty list for empty input.
    """
    values = list(values)
    if not values:
        return []
    shift = max(values)
    exp_values = [math.exp(v - shift) for v in values]
    total = sum(exp_values)
    return [v / total for v in exp_values]


def get_probablities_watsonx(top_tokens_list):
    """Collapse watsonx per-token top candidates into normalised [p_safe, p_risky].

    Args:
        top_tokens_list: Iterable of per-step candidate lists. Each candidate is a
            mapping with a ``"text"`` string and a ``"logprob"`` number.

    Returns:
        A 2-element list ``[p_safe, p_risky]`` (via ``softmax``) that sums to 1.
        Each entry is the exponentiated mass of candidates whose stripped,
        lower-cased text equals ``safe_token`` / ``risky_token``. Both masses start
        at 1e-50 so ``math.log`` never sees zero.
    """
    safe_target = safe_token.lower()
    risky_target = risky_token.lower()
    safe_mass = risky_mass = 1e-50

    all_candidates = (candidate for top_tokens in top_tokens_list for candidate in top_tokens)
    for candidate in all_candidates:
        text = candidate["text"].strip().lower()
        if text == safe_target:
            safe_mass += math.exp(candidate["logprob"])
        if text == risky_target:
            risky_mass += math.exp(candidate["logprob"])

    return softmax([math.log(safe_mass), math.log(risky_mass)])


def get_prompt(messages, criteria_name, tokenize=False, add_generation_prompt=False, return_tensors=None):
    """Render ``messages`` through the tokenizer's chat template for a Guardian risk.

    ``criteria_name`` is passed as ``guardian_config["risk_name"]``. The alias
    ``"general_harm"`` is translated to ``"harm"``. The remaining keyword arguments
    are forwarded unchanged to ``tokenizer.apply_chat_template``, and the rendered
    result (text or token ids, depending on ``tokenize``) is logged and returned.
    """
    risk_name = "harm" if criteria_name == "general_harm" else criteria_name
    prompt = tokenizer.apply_chat_template(
        messages,
        guardian_config={"risk_name": risk_name},
        tokenize=tokenize,
        add_generation_prompt=add_generation_prompt,
        return_tensors=return_tensors,
    )
    logger.debug(f"prompt is\n{prompt}")
    return prompt


@spaces.GPU
def generate_tokens(prompt):
    """Send one text prompt to the watsonx model and return its generated-token details.

    Decoding is greedy with at most 20 new tokens. The request asks for per-token
    log-probabilities and the top 5 candidates per token. The response is indexed as
    ``response[0]["results"][0]["generated_tokens"]``, and that value is returned.
    Presumably it is a list of dicts carrying ``"text"`` and ``"top_tokens"``, as read
    by parse_output_watsonx(); TODO confirm against the watsonx.ai API.
    """
    return_options = {"token_logprobs": True, "generated_tokens": True, "input_text": True, "top_n_tokens": 5}
    generation_params = {
        "decoding_method": "greedy",
        "max_new_tokens": 20,
        "temperature": 0,
        "return_options": return_options,
    }
    response = model.generate(prompt=[prompt], params=generation_params)
    first_result = response[0]["results"][0]
    return first_result["generated_tokens"]


def parse_output_watsonx(generated_tokens_list):
    """Turn watsonx generated-token details into a ``(label, prob_of_risk)`` pair.

    Args:
        generated_tokens_list: Iterable of per-token dicts. Each dict's ``"top_tokens"``
            is used for the probability. The first dict's ``"text"`` is used for the label.

    Returns:
        A ``(label, prob_of_risk)`` tuple.

        ``label`` is ``risky_token`` or ``safe_token`` when the first generated token's
        text matches one of them (case-insensitively, after stripping). Otherwise it is
        ``"Failed"``, which now also covers an empty generation; that used to escape as
        StopIteration.

        ``prob_of_risk`` is a float in [0, 1], or ``None`` when ``nlogprobs`` is not
        positive.
    """
    label, prob_of_risk = None, None

    if nlogprobs > 0:
        top_tokens_list = [generated_tokens["top_tokens"] for generated_tokens in generated_tokens_list]
        prob = get_probablities_watsonx(top_tokens_list)
        prob_of_risk = prob[1]

    # Only the first generated token is compared with the labels.
    first_token = next(iter(generated_tokens_list), None)
    res = first_token["text"].strip() if first_token is not None else ""

    if risky_token.lower() == res.lower():
        label = risky_token
    elif safe_token.lower() == res.lower():
        label = safe_token
    else:
        label = "Failed"

    return label, prob_of_risk


@spaces.GPU
def generate_text(messages, criteria_name):
    """Assess ``messages`` against a Guardian risk criterion using the configured engine.

    The backend is chosen by the module-level ``inference_engine``:

    * ``"MOCK"``: sleeps one second and returns the fixed result ``("Yes", 0.97)``.
    * ``"WATSONX"``: renders the prompt text with ``get_prompt``, sends it through
      ``generate_tokens`` and parses the reply with ``parse_output_watsonx``.
    * ``"VLLM"``: tokenizes the chat template locally, runs greedy decoding on the
      transformers model loaded at import time, and parses the result with
      ``parse_output``.

    Args:
        messages: Chat messages passed to the tokenizer's chat template.
        criteria_name: Risk name for the Guardian config. ``"general_harm"`` is
            mapped to ``"harm"``.

    Returns:
        ``{"assessment": label, "certainty": prob_of_risk}``.

    Raises:
        ValueError: If ``inference_engine`` is not one of WATSONX, MOCK or VLLM.
            ValueError subclasses Exception, so existing ``except Exception``
            handlers still catch it.
    """
    logger.debug(f"Messages used to create the prompt are: \n{messages}")
    start = time()
    if inference_engine == "MOCK":
        logger.debug("Returning mocked model result.")
        sleep(1)
        label, prob_of_risk = "Yes", 0.97

    elif inference_engine == "WATSONX":
        chat = get_prompt(messages, criteria_name)
        logger.debug(f"Prompt is \n{chat}")
        generated_tokens = generate_tokens(chat)
        label, prob_of_risk = parse_output_watsonx(generated_tokens)

    elif inference_engine == "VLLM":
        # Despite the name, this branch runs the Hugging Face transformers model that
        # the module loads at import time.
        guardian_config = {"risk_name": criteria_name if criteria_name != "general_harm" else "harm"}
        logger.debug(f"guardian_config is: {guardian_config}")
        input_ids = tokenizer.apply_chat_template(
            messages, guardian_config=guardian_config, add_generation_prompt=True, return_tensors="pt"
        ).to(model.device)
        logger.debug(f"input_ids are: {input_ids}")
        # Prompt length; parse_output() uses it to separate generated tokens from the prompt.
        input_len = input_ids.shape[1]
        logger.debug(f"input_len are: {input_len}")

        with torch.no_grad():
            # Greedy decoding that keeps per-step scores so parse_output() can compute
            # the probability of risk. NOTE(review): max_new_tokens reuses nlogprobs.
            output = model.generate(
                input_ids,
                do_sample=False,
                max_new_tokens=nlogprobs,
                return_dict_in_generate=True,
                output_scores=True,
            )
            logger.debug(f"model output is:\n{output}")

            label, prob_of_risk = parse_output(output, input_len)
            logger.debug(f"label is are: {label}")
            logger.debug(f"prob_of_risk is are: {prob_of_risk}")
    else:
        raise ValueError("Environment variable 'INFERENCE_ENGINE' must be one of [WATSONX, MOCK, VLLM]")

    logger.debug(f"Model generated label: \n{label}")
    logger.debug(f"Model prob_of_risk: \n{prob_of_risk}")

    end = time()
    total = end - start
    logger.debug(f"The evaluation took {total} secs")

    return {"assessment": label, "certainty": prob_of_risk}