"""Granite Guardian risk-assessment backend.

The backend is chosen at import time from the ``INFERENCE_ENGINE`` environment
variable (default ``"TORCH"``):

* ``TORCH``   - load the Hugging Face model locally and move it to CUDA.
* ``WATSONX`` - call the hosted model through the IBM watsonx.ai API.
* ``MOCK``    - load nothing; ``get_guardian_response`` returns a fixed result.
"""

import math
import os
from time import sleep, time

import spaces
from ibm_watsonx_ai.client import APIClient
from ibm_watsonx_ai.foundation_models import ModelInference
from transformers import AutoModelForCausalLM, AutoTokenizer

from logger import logger

# Tokens the model emits for a "safe" / "risky" verdict.
safe_token = "No"
risky_token = "Yes"
# Top-k width per generated step on the TORCH path. It is also reused as the
# TORCH generation length, and it acts as the on/off switch (> 0) for
# computing the probability of risk on both paths.
nlogprobs = 20

# Criteria names used by callers that differ from the risk names passed to the
# chat template's ``guardian_config``.
_CRITERIA_ALIASES = {
    "general_harm": "harm",
    "function_calling_hallucination": "function_call",
}

inference_engine = os.getenv("INFERENCE_ENGINE", "TORCH")
logger.debug(f"Inference engine is: '{inference_engine}'")

if inference_engine == "TORCH":
    import torch

    device = torch.device("cuda")

    model_path = os.getenv("MODEL_PATH", "ibm-granite/granite-guardian-3.1-8b")
    logger.debug(f"model_path is {model_path}")
    tokenizer = AutoTokenizer.from_pretrained(model_path)
    model = AutoModelForCausalLM.from_pretrained(model_path)
    model = model.to(device).eval()

elif inference_engine == "WATSONX":
    client = APIClient(
        credentials={"api_key": os.getenv("WATSONX_API_KEY"), "url": "https://us-south.ml.cloud.ibm.com"}
    )
    client.set.default_project(os.getenv("WATSONX_PROJECT_ID"))

    # On this path the local tokenizer only renders the chat template into a
    # prompt string; generation happens remotely.
    # NOTE(review): the tokenizer is granite-guardian-3.1 while the hosted
    # model id is granite-guardian-3-8b — confirm their chat templates match.
    hf_model_path = "ibm-granite/granite-guardian-3.1-8b"
    tokenizer = AutoTokenizer.from_pretrained(hf_model_path)

    model_id = "ibm/granite-guardian-3-8b"  # 8B Model: "ibm/granite-guardian-3-8b"
    model = ModelInference(model_id=model_id, api_client=client)


def _matches_token(text, token):
    """Return True if *text* equals *token*, ignoring surrounding whitespace and case."""
    return text.strip().lower() == token.lower()


def _label_from_text(text):
    """Map decoded model output to ``risky_token``, ``safe_token`` or ``"Failed"``."""
    if _matches_token(text, risky_token):
        return risky_token
    if _matches_token(text, safe_token):
        return safe_token
    return "Failed"


def get_probablities_watsonx(top_tokens_list):
    """Return ``[P(safe), P(risky)]`` from watsonx top-token log-probabilities.

    Args:
        top_tokens_list: one entry per generated position; each entry is an
            iterable of candidate dicts with ``"text"`` and ``"logprob"`` keys.

    Returns:
        A two-element list of floats: the normalized probability of the safe
        token, then of the risky token.
    """
    # Seed with a tiny epsilon so math.log() below never sees 0 when a token
    # does not appear among the candidates.
    safe_token_prob = 1e-50
    risky_token_prob = 1e-50
    for top_tokens in top_tokens_list:
        for token in top_tokens:
            if _matches_token(token["text"], safe_token):
                safe_token_prob += math.exp(token["logprob"])
            if _matches_token(token["text"], risky_token):
                risky_token_prob += math.exp(token["logprob"])

    return softmax([math.log(safe_token_prob), math.log(risky_token_prob)])


def parse_output_watsonx(generated_tokens_list):
    """Turn watsonx ``generated_tokens`` into ``(label, prob_of_risk)``.

    Args:
        generated_tokens_list: the ``generated_tokens`` list returned by
            ``generate_tokens_watsonx``; each item has ``"text"`` and
            ``"top_tokens"`` keys.

    Returns:
        ``label`` is ``risky_token``, ``safe_token`` or ``"Failed"``, decided
        from the first generated token only. An empty generation is
        ``"Failed"``. ``prob_of_risk`` is a float, or ``None`` when
        ``nlogprobs <= 0``.
    """
    prob_of_risk = None

    if nlogprobs > 0:
        top_tokens_list = [generated_tokens["top_tokens"] for generated_tokens in generated_tokens_list]
        prob_of_risk = get_probablities_watsonx(top_tokens_list)[1]

    # Only the first generated token carries the verdict.
    first_token = next(iter(generated_tokens_list), None)
    label = "Failed" if first_token is None else _label_from_text(first_token["text"])
    return label, prob_of_risk


def generate_tokens_watsonx(prompt):
    """Run greedy generation on watsonx and return the per-token details.

    Args:
        prompt: the fully rendered prompt string.

    Returns:
        The ``generated_tokens`` list of the first result of the first prompt,
        requested with the top 5 candidate tokens per position.
    """
    result = model.generate(
        prompt=[prompt],
        params={
            "decoding_method": "greedy",
            "max_new_tokens": 20,
            "temperature": 0,
            "return_options": {"token_logprobs": True, "generated_tokens": True, "input_text": True, "top_n_tokens": 5},
        },
    )
    return result[0]["results"][0]["generated_tokens"]


def softmax(values):
    """Return the softmax of a sequence of floats as a list.

    The maximum is subtracted before exponentiating so that large inputs do
    not overflow ``math.exp``. The result is mathematically unchanged. An empty
    input returns an empty list.
    """
    if not values:
        return []
    peak = max(values)
    exp_values = [math.exp(v - peak) for v in values]
    total = sum(exp_values)
    return [v / total for v in exp_values]


def get_probablities(logprobs):
    """Return a 2-element tensor ``[P(safe), P(risky)]`` from top-k step scores.

    Args:
        logprobs: one ``torch.topk`` result per generated step. Only batch row
            0 of each result's ``values`` / ``indices`` is read.
            NOTE(review): ``parse_output`` builds these from
            ``output.scores``, which ``generate`` returns as processed logits,
            not log-softmaxed values. They are exponentiated as if they were
            log-probabilities; the final softmax renormalizes — confirm this
            is intended.

    Returns:
        A ``torch`` tensor of shape ``(2,)``.
    """
    # Epsilon seed keeps math.log() finite when a token never appears.
    safe_token_prob = 1e-50
    unsafe_token_prob = 1e-50
    for gen_token_i in logprobs:
        for logprob, index in zip(gen_token_i.values.tolist()[0], gen_token_i.indices.tolist()[0]):
            decoded_token = tokenizer.convert_ids_to_tokens(index)
            if _matches_token(decoded_token, safe_token):
                safe_token_prob += math.exp(logprob)
            if _matches_token(decoded_token, risky_token):
                unsafe_token_prob += math.exp(logprob)

    return torch.softmax(torch.tensor([math.log(safe_token_prob), math.log(unsafe_token_prob)]), dim=0)


def parse_output(output, input_len):
    """Turn a ``model.generate`` result into ``(label, prob_of_risk)``.

    Args:
        output: the object returned by ``model.generate`` with
            ``return_dict_in_generate=True`` and ``output_scores=True``.
        input_len: number of prompt tokens; generated tokens start after it.

    Returns:
        ``label`` is ``risky_token``, ``safe_token`` or ``"Failed"``.
        ``prob_of_risk`` is a Python float, or ``None`` when
        ``nlogprobs <= 0`` (previously this case crashed on ``None.item()``).
    """
    prob_of_risk = None

    if nlogprobs > 0:
        # Top-k candidates for every generated step except the last one.
        top_k_per_step = [
            torch.topk(token_i, k=nlogprobs, largest=True, sorted=True) for token_i in list(output.scores)[:-1]
        ]
        prob_of_risk = get_probablities(top_k_per_step)[1].item()

    # Decode only the newly generated part of the first sequence.
    res = tokenizer.decode(output.sequences[:, input_len:][0], skip_special_tokens=True)
    return _label_from_text(res), prob_of_risk


# NOTE(review): this function only renders a chat template; confirm the GPU
# allocation requested by @spaces.GPU is actually needed here.
@spaces.GPU
def get_prompt(messages, criteria_name, tokenize=False, add_generation_prompt=False, return_tensors=None):
    """Render *messages* through the tokenizer's Granite Guardian chat template.

    Args:
        messages: chat messages passed straight to ``apply_chat_template``.
        criteria_name: risk to evaluate. ``"general_harm"`` and
            ``"function_calling_hallucination"`` are mapped to the template's
            ``"harm"`` and ``"function_call"`` risk names.
        tokenize, add_generation_prompt, return_tensors: forwarded unchanged
            to ``tokenizer.apply_chat_template``.

    Returns:
        Whatever ``apply_chat_template`` returns for these options (a prompt
        string when ``tokenize=False``).
    """
    logger.debug("Creating prompt for the model.")
    logger.debug(f"Messages used to create the prompt are: \n{messages}")
    logger.debug("Criteria name is: " + criteria_name)

    risk_name = _CRITERIA_ALIASES.get(criteria_name, criteria_name)
    if risk_name != criteria_name:
        logger.debug("Criteria name was changed to: " + risk_name)

    logger.debug(f"Tokenize: {tokenize}")
    logger.debug(f"add_generation_prompt: {add_generation_prompt}")
    logger.debug(f"return_tensors: {return_tensors}")

    guardian_config = {"risk_name": risk_name}
    logger.debug(f"guardian_config is: {guardian_config}")

    prompt = tokenizer.apply_chat_template(
        messages,
        guardian_config=guardian_config,
        tokenize=tokenize,
        add_generation_prompt=add_generation_prompt,
        return_tensors=return_tensors,
    )
    logger.debug(f"Prompt (type {type(prompt)}) is: {prompt}")

    return prompt


@spaces.GPU
def get_guardian_response(messages, criteria_name):
    """Assess *messages* for the given risk with the configured backend.

    Args:
        messages: chat messages to evaluate.
        criteria_name: risk name, see ``get_prompt``.

    Returns:
        ``{"assessment": label, "certainty": prob_of_risk}``. ``label`` is
        ``"Yes"`` (risky), ``"No"`` (safe) or ``"Failed"``. ``prob_of_risk``
        is the model's probability of risk, or ``None`` if it is not computed.
        The ``MOCK`` engine always returns ``"Yes"`` / ``0.97`` after a
        1-second sleep.

    Raises:
        ValueError: if ``INFERENCE_ENGINE`` is not one of WATSONX, MOCK, TORCH.
    """
    start = time()

    if inference_engine == "MOCK":
        logger.debug("Returning mocked model result.")
        sleep(1)
        label, prob_of_risk = "Yes", 0.97

    elif inference_engine == "WATSONX":
        chat = get_prompt(messages, criteria_name)
        logger.debug(f"Prompt is \n{chat}")
        generated_tokens = generate_tokens_watsonx(chat)
        label, prob_of_risk = parse_output_watsonx(generated_tokens)

    elif inference_engine == "TORCH":
        input_ids = get_prompt(
            messages=messages,
            criteria_name=criteria_name,
            tokenize=True,
            add_generation_prompt=True,
            return_tensors="pt",
        ).to(model.device)
        input_len = input_ids.shape[1]
        logger.debug(f"input_len is: {input_len}")

        with torch.no_grad():
            output = model.generate(
                input_ids,
                do_sample=False,
                max_new_tokens=nlogprobs,
                return_dict_in_generate=True,
                output_scores=True,
            )

        label, prob_of_risk = parse_output(output, input_len)
        logger.debug(f"Label is: {label}")
        logger.debug(f"Prob_of_risk is: {prob_of_risk}")

    else:
        # ValueError subclasses Exception, so existing `except Exception` handlers still apply.
        raise ValueError("Environment variable 'INFERENCE_ENGINE' must be one of [WATSONX, MOCK, TORCH]")

    logger.debug(f"Model generated label: {label}")
    logger.debug(f"Model prob_of_risk: {prob_of_risk}")

    elapsed = time() - start
    logger.debug(f"The evaluation took {elapsed} secs")

    return {"assessment": label, "certainty": prob_of_risk}