llm-explain / src /llm_explain /service /responsible_ai_explain.py
Kiran5's picture
Add application file
563c5bc
'''
Copyright 2024 Infosys Ltd.
Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"),
to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense,
and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all copies
or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED,
INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE
AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM,
DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
'''
from llm_explain.utility.query_serper import GoogleSerperAPIWrapper
from llm_explain.utility import got as GraphOfThoughts
from llm_explain.utility.prompts.base import Prompt
from llm_explain.config.logger import CustomLogger
from llm_explain.utility.utility import Utils
from llm_explain.utility.azure import Azure
from sklearn.metrics.pairwise import cosine_similarity
from json.decoder import JSONDecodeError
from scipy.stats import gaussian_kde
from itertools import combinations
import plotly.graph_objects as go
import matplotlib.pyplot as plt
from openai import AzureOpenAI
import pandas as pd
import numpy as np
import matplotlib
import asyncio
import base64
import time
import html
import json
import ast
import os
log = CustomLogger()
class ResponsibleAIExplain:
@staticmethod
def llm_response_to_json(response):
"""
Converts a substring of the given response that is in JSON format into a Python dictionary.
This function searches for the first occurrence of '{' and the last occurrence of '}' to find the JSON substring.
It then attempts to parse this substring into a Python dictionary. If the parsing is successful, the dictionary
is returned. If the substring is not valid JSON, the function will return None.
Parameters:
- response (str): The response string that potentially contains JSON content.
Returns:
- dict: A dictionary representation of the JSON substring found within the response.
- None: If no valid JSON substring is found or if an error occurs during parsing.
"""
try:
result = None # Initialize result to None in case no valid JSON is found
# Find the start index of the first '{' character and end index of the last '}' character
start_index = response.find('{')
end_index = response.rfind('}')
# Check if both '{' and '}' are found and '{' comes before '}'
if start_index != -1 and end_index != -1 and end_index > start_index:
json_content = response[start_index:end_index+1] # Extract the substring that is potentially in JSON format
result = json.loads(json_content) # Attempt to parse the JSON substring into a Python dictionary
return result
except Exception as e:
# Log the exception if any error occurs during parsing
log.error(f"An error occurred while parsing JSON from response: {e}", exc_info=True)
raise
async def analyze_heatmap(df_input):
base64_encoded_imgs=[]
try:
df = df_input.copy()
if "token" not in df.columns or "importance_value" not in df.columns:
raise ValueError("The DataFrame must contain 'token' and 'importance_value' columns.")
df["Position"] = range(len(df))
# Calculate histogram data
hist, bin_edges = np.histogram(df["importance_value"], bins=20)
# Get the viridis colormap
viridis = plt.get_cmap("viridis")
# Initialize the figure
fig = go.Figure()
# Create the histogram bars with viridis coloring
for i, freq in enumerate(hist):
color = f"rgb({int(viridis(i / (len(bin_edges) - 1))[0] * 255)}, {int(viridis(i / (len(bin_edges) - 1))[1] * 255)}, {int(viridis(i / (len(bin_edges) - 1))[2] * 255)})"
fig.add_trace(
go.Bar(
x=[(bin_edges[i] + bin_edges[i + 1]) / 2],
y=[freq],
width=np.diff(bin_edges)[i],
marker=dict(color=color),
)
)
# Calculate and add the KDE line
x_kde = np.linspace(min(df["importance_value"]), max(df["importance_value"]), 500)
kde = gaussian_kde(df["importance_value"])
y_kde = kde(x_kde) * sum(hist) * (bin_edges[1] - bin_edges[0])
fig.add_trace(
go.Scatter(
x=x_kde, y=y_kde, mode="lines", line_shape="spline", line=dict(color="red")
)
)
# Additional styling
fig.update_layout(
title=" Distribution of Importance Scores",
title_font={'size': 25},
xaxis_title="Importance Value",
yaxis_title="Frequency",
showlegend=False,
)
img_bytes = fig.to_image(format="png")
img_base64 = base64.b64encode(img_bytes).decode("utf-8")
base64_encoded_imgs.append(img_base64)
# Normalize the importance values
min_val = df["importance_value"].min()
max_val = df["importance_value"].max()
normalized_values = (df["importance_value"] - min_val) / (max_val - min_val)
# Initialize the figure
fig = go.Figure()
# Create the bars, colored based on normalized importance_value
for i, (token, norm_value) in enumerate(zip(df["token"], normalized_values)):
color = f"rgb({int(viridis(norm_value)[0] * 255)}, {int(viridis(norm_value)[1] * 255)}, {int(viridis(norm_value)[2] * 255)})"
fig.add_trace(
go.Bar(
x=[i], # Use index for x-axis
y=[df["importance_value"].iloc[i]],
width=0.9, # Set the width to make bars touch each other
marker=dict(color=color),
)
)
# Additional styling
fig.update_layout(
title="Importance Score per Token",
title_font={'size': 25},
xaxis_title="Token",
yaxis_title="Importance Value",
showlegend=False,
bargap=0, # Remove gap between bars
xaxis=dict( # Set tick labels to tokens
tickmode="array",
tickvals=list(range(len(df["token"]))),
ticktext=list(df["token"]),
),
autosize=False, # Disable automatic sizing
width= max(10, len(df["token"]) * 0.3) * 100, # Convert to pixels
)
# Rotate x-axis labels by 45 degrees
fig.update_xaxes(tickangle=-45)
img_bytes = fig.to_image(format="png")
img_base64 = base64.b64encode(img_bytes).decode("utf-8")
base64_encoded_imgs.append(img_base64)
top_10_important = df.nlargest(10, 'importance_value')
top_10=top_10_important.to_dict(orient='records')
# Extract the importance scores
importance_values = df["importance_value"].values
# Normalize the importance scores to be between 0 and 1
min_val = np.min(importance_values)
max_val = np.max(importance_values)
if max_val - min_val != 0:
normalized_importance_values = (importance_values - min_val) / (max_val - min_val)
else:
normalized_importance_values = np.zeros_like(importance_values)
# Generate a colormap for the heatmap
cmap = matplotlib.colormaps["inferno"]
# Helper function to determine the text color based on the background color
def get_text_color(bg_color):
brightness = 0.299 * bg_color[0] + 0.587 * bg_color[1] + 0.114 * bg_color[2]
if brightness < 0.5:
return "white"
else:
return "black"
# Initialize HTML string
html_string = ""
# Loop over tokens and construct the HTML string
for idx, (token, importance) in df_input.iterrows():
rgba = cmap(normalized_importance_values[idx])
bg_color = rgba[:3]
text_color = get_text_color(bg_color)
# Explicitly handle special characters
token_escaped = html.escape(token).replace('`', '&#96;').replace('$', '&#36;') # Handle backticks and dollar signs
html_string += f"<span style='background-color: rgba({int(bg_color[0]*255)}, {int(bg_color[1]*255)}, {int(bg_color[2]*255)}, 1); color: {text_color};'>{token_escaped}</span> "
return top_10,base64_encoded_imgs,html_string
except Exception as e:
log.error(e, exc_info=True)
raise
async def calculate_uncertainty(n : int, prompt: str):
try:
max_tokens=1000
client = AzureOpenAI(
api_key = os.getenv("AZURE_OPENAI_API_KEY") ,
api_version = os.getenv("AZURE_OPENAI_API_VERSION") ,
azure_endpoint = os.getenv("AZURE_OPENAI_ENDPOINT")
)
try:
response = client.chat.completions.create(
n=n,
model=os.getenv("AZURE_DEPLOYMENT_ENGINE"), # model = "deployment_name".
messages=[
{"role": "system", "content": "Assistant is a large language model trained by OpenAI."},
{"role": "user", "content": prompt}
],
logprobs=True,
top_logprobs=2,
max_tokens=100
)
except Exception as e:
log.error("An error occurred while calling the AzureOpenAI API", exc_info=True)
raise Exception
cc=response.choices
response_object ={}
choices = []
for i,c in enumerate(cc):
contents=c.logprobs.content
choice_i={
"text": c.message.content
}
logprobs = {}
token_logprobs = []
tokens=[]
top_logprobs=[]
for content in contents:
token_logprobs.append(content.logprob)
temp={}
tokens.append(content.token)
top_props=content.top_logprobs
for k in top_props:
temp[k.token]=k.logprob
top_logprobs.append(temp)
logprobs["token_logprobs"]=token_logprobs
logprobs["tokens"]=tokens
logprobs["top_logprobs"]=top_logprobs
choice_i["logprobs"]=logprobs
choice_i["index"]=i
choices.append(choice_i)
response_object["choices"]=choices
entropies = []
distances = []
choice_embeddings = []
choice_embedding_tasks = [Utils.get_embedding(choice['text']) for choice in response_object['choices']]
choice_embeddings = await asyncio.gather(*choice_embedding_tasks)
async def process_choice(choice, choice_embedding):
top_logprobs_list = choice['logprobs']['top_logprobs']
mean_cosine_distances = []
normalized_entropies = []
tasks = [Utils.process_token_async(i, top_logprobs_list, choice, choice_embedding, max_tokens) for i in range(len(top_logprobs_list))]
results = await asyncio.gather(*tasks)
for mean_distance, normalized_entropy in results:
mean_cosine_distances.append(mean_distance)
normalized_entropies.append(normalized_entropy)
return mean_cosine_distances, normalized_entropies
choice_tasks = [process_choice(choice, emb) for choice, emb in zip(response_object['choices'], choice_embeddings)]
results = await asyncio.gather(*choice_tasks)
for mean_cosine_distances, normalized_entropies in results:
distances.append(mean_cosine_distances)
entropies.append(normalized_entropies)
choice_distances = []
for emb1, emb2 in combinations(choice_embeddings, 2):
cosine_sim = cosine_similarity(emb1.reshape(1, -1), emb2.reshape(1, -1))[0][0]
choice_distances.append(1 - cosine_sim)
mean_choice_distance = np.mean(choice_distances)
uncertainty_scores = {'entropies': entropies, 'distances': distances, 'mean_choice_distance': mean_choice_distance}
return Utils.display_metrics(uncertainty_scores, response_object, n)
except Exception as e:
log.error(e, exc_info=True)
raise Exception
@staticmethod
def normalize_scores(dict_list):
try:
# Calculate the total sum of all importance scores
total_sum = sum(d['importance_score'] for d in dict_list)
# If the total sum is zero, return the original list (to handle cases where all scores are zero)
if total_sum == 0:
return dict_list
# Normalize the scores to ensure their sum equals 100
normalized_scores = [round((d['importance_score'] / total_sum) * 100) for d in dict_list]
# Adjust the scores to ensure their sum is exactly 100
adjustment = 100 - sum(normalized_scores)
normalized_scores[0] += adjustment
# Update the original list with normalized scores
for i, d in enumerate(dict_list):
d['importance_score'] = normalized_scores[i]
return dict_list
except KeyError as e:
log.error(f"KeyError: Missing key in one of the dictionaries - {e}")
raise
except TypeError as e:
log.error(f"TypeError: Invalid type encountered - {e}")
raise
except Exception as e:
log.error(f"An unexpected error occurred: {e}")
raise
@staticmethod
def filter_token_importance(scores, anchors):
import re
try:
# Split each phrase in anchors into individual words, remove special characters, and convert to lowercase
anchors = [re.sub(r'\W+', '', word).lower() for anchor in anchors for word in anchor.split()]
importance_scores = [] # Initialize a list to store the importance scores of the anchors
for score in scores: # Iterate through the scores list
cleaned_token = re.sub(r'\W+', '', str(score['token'])).lower()
if cleaned_token in anchors: # Check if the token value is in the anchors list
importance_scores.append(score['importance_score']) # Append the importance score to the list
# Calculate the remaining importance score
x = 100 - sum(importance_scores)
filtered_tokens = []
for score in scores: # Iterate through the scores list
cleaned_token = re.sub(r'\W+', '', str(score['token'])).lower()
if cleaned_token in anchors: # Check if the token value is in the anchors list
updated_importance = {'token': score['token'],
'importance_score': score['importance_score'] + (x / len(importance_scores)),
'position': score['position']}
filtered_tokens.append(updated_importance) # Append the updated importance score to the new list
return filtered_tokens
except KeyError as e:
log.error(f"KeyError: Missing key in one of the dictionaries - {e}")
raise
except TypeError as e:
log.error(f"TypeError: Invalid type encountered - {e}")
raise
except ZeroDivisionError as e:
log.error(f"ZeroDivisionError: Division by zero encountered - {e}")
raise
except Exception as e:
log.error(f"An unexpected error occurred: {e}")
raise
def sentiment_analysis(text: str, class_names):
log.info("Running local_explain")
try:
start_time = time.time()
explanation = Azure().generate(Prompt.get_classification_prompt(text))
end_time = time.time()
total_time = round(end_time-start_time, 3)
explanation = ResponsibleAIExplain.llm_response_to_json(explanation)
print('explanation', explanation)
# Normalize the importance scores to ensure their sum equals 100
explanation['token_importance_mapping'] = ResponsibleAIExplain.normalize_scores(explanation['token_importance_mapping'])
# Extract the top 10 important tokens
tokens_mapping = ResponsibleAIExplain.filter_token_importance(explanation['token_importance_mapping'], explanation['Keywords'])
return {"predictedTarget": explanation['Sentiment'],
"anchor": explanation['Keywords'],
"explanation": explanation['Explanation'],
"token_importance_mapping": tokens_mapping,
"time_taken": total_time}
except Exception as e:
log.error(e,exc_info=True)
raise
async def local_explanation(prompt: str, response: str):
try:
start_time = time.time()
explanation = Azure().generate(Prompt.get_local_explanation_prompt(prompt, response))
end_time = time.time()
total_time = round(end_time-start_time, 3)
explanation = ResponsibleAIExplain.llm_response_to_json(explanation)
explanation['time_taken'] = total_time
return explanation
except Exception as e:
log.error(e, exc_info=True)
raise
async def process_importance(importance_function, *args, **kwargs):
try:
start_time = time.time()
importance_map = await importance_function(*args, **kwargs)
importance_map_df = pd.DataFrame(importance_map, columns=['token', 'importance_value'])
offset = importance_map_df['importance_value'].mean()
importance_log = Utils.scale_importance_log(
importance_map,
base=None,
offset=offset,
min_percentile=0,
max_percentile=100,
scaling_factor=1,
bias=0
)
importance_log_df = pd.DataFrame(importance_log, columns=['token', 'importance_value'])
end_time = time.time()
total_time = round(end_time-start_time, 3)
return importance_log_df, total_time
except Exception as e:
log.error(e, exc_info=True)
raise
async def prompt_based_token_importance(prompt):
try:
start_time = time.time()
max_retries = 5
attempts = 0
while attempts < max_retries:
try:
explanation = Azure().generate(Prompt.get_token_importance_prompt(prompt))
# Manually find the JSON substring within the mixed content
start_index = explanation.find('{')
end_index = explanation.rfind('}')
if start_index != -1 and end_index != -1 and end_index > start_index:
json_content = explanation[start_index:end_index+1]
result = json.loads(json_content)
# If JSON loads successfully, break out of the loop
break
except JSONDecodeError:
attempts += 1
if attempts == max_retries:
raise Exception("Failed to decode JSON after 5 attempts.")
else:
log.debug(f"JSONDecodeError encountered. Retrying... Attempt {attempts}/{max_retries}")
# Add a delay before the next attempt
time.sleep(2) # Delay for 2 seconds
# Assuming 'result' is a dictionary with "Token" and "Importance Score" as keys, and their values are lists
# First, create a DataFrame from the 'result' dictionary
tokens = result['Token']
scores = result['Importance Score']
positions = list(range(1, len(result['Token']) + 1))
# Find the length of the shortest list
min_length = min(len(tokens), len(scores), len(positions))
# Trim the lists to the length of the shortest list
tokens = tokens[:min_length]
scores = scores[:min_length]
positions = positions[:min_length]
df = pd.DataFrame({
'token': tokens,
'importance_value': scores,
'position': positions
})
df['importance_value'] = df['importance_value'].astype(float)
# Sort the DataFrame by 'Importance Score' in descending order to get the most important tokens first
df_sorted = df.sort_values(by='importance_value', ascending=False)
# Select the top 10 important tokens
df_top10 = df_sorted.head(10)
df_top10.reset_index(drop=True, inplace=True)
end_time = time.time()
total_time = round(end_time-start_time, 3)
top_10, base64_encoded_imgs, token_heatmap = await ResponsibleAIExplain.analyze_heatmap(df_top10[['token', 'importance_value']])
return df_top10.to_dict(orient='records'), base64_encoded_imgs, token_heatmap, total_time
except Exception as e:
log.error(e, exc_info=True)
raise
async def graph_of_thoughts(prompt: str, modelName: str):
try:
start_time = time.time()
budget = 30
task = "answer the following question"
question = prompt
approaches = [GraphOfThoughts.got]
modelName = modelName
formatted_graph, formatted_thoughts = GraphOfThoughts.run(task=task, question=question,
methods=approaches,
budget=budget,
lm_name=modelName)
formatted_graph[3]['operation'] = 'final_thought'
for i in range(4):
thoughts = formatted_graph[i]['thoughts']
for j in range(len(thoughts)):
formatted_graph[i]['thoughts'][j]['score'] = round(formatted_graph[i]['thoughts'][j]['score'], 2)
end_time = time.time()
total_time = round(end_time-start_time, 3)
return formatted_graph, formatted_thoughts, total_time
except Exception as e:
log.error(e, exc_info=True)
raise
async def search_augmentation(inputPrompt, llmResponse):
try:
import datetime
current_date = datetime.datetime.now()
start_time = time.time()
# Step 1: Generate Facts with LLM response
facts = Azure().generate(Prompt.generate_facts_prompt(inputPrompt, llmResponse, current_date))
if isinstance(facts, str):
facts = ResponsibleAIExplain.llm_response_to_json(facts)
facts_list = [fact['Fact'] for fact in facts['Facts']] # Extracting the facts into a list of strings
# Step 2: Filter the facts that are relevant to the input prompt
filtered_facts = Azure().generate(Prompt.filter_facts_prompt(prompt=inputPrompt, facts=facts_list))
filtered_facts = ast.literal_eval(filtered_facts)
filtered_facts_ir = [fact + ' is this statement valid as of today ? why ? #' for fact in filtered_facts]
questions = [inputPrompt] + filtered_facts_ir
# Step 3: Run the prompt and facts through the Google Serper API
search = GoogleSerperAPIWrapper()
internet_responses = await search.run([inputPrompt])
answers = [item[0]['content'] for item in internet_responses]
qa_dict_list_prompt = [{'question': q, 'answer': a} for q, a in zip([inputPrompt], answers)] # Creating the list of dictionaries
internet_responses = await search.run(questions)
answers_facts = [item[0]['content'] for item in internet_responses]
qa_dict_list = [{'question': q, 'answer': a} for q, a in zip(questions, answers_facts)] # Creating the list of dictionaries
if len(facts_list) == 0:
return {'internetResponse': qa_dict_list_prompt[0]['answer'],
'factual_check': {"Score": 0.0,
"explanation_factual_accuracy": {'Result': ['No facts found in the LLM response.']} }}
# Step 4: Summarize the internet responses for prompt and facts
summary_prompt = Azure().generate(Prompt.summarize_prompt(qa_dict_list_prompt))
# Step 5: Evaluate fact with Google Search results
facts = Azure().generate(Prompt.evaluate_facts_prompt(facts=filtered_facts_ir, context=qa_dict_list, prompt=inputPrompt))
if isinstance(facts, str):
facts = ResponsibleAIExplain.llm_response_to_json(facts)
# In facts['Result'], each fact is a dictionary with keys 'Fact', 'Reasoning', and 'Judgement', update Fact with the filtered facts
for i, fact in enumerate(facts['Result']):
fact['Fact'] = filtered_facts[i]
factuality_check = { "Score": 1.0,
"explanation_factual_accuracy": facts }
end_time = time.time()
total_time = round(end_time-start_time, 3)
return {'internetResponse': summary_prompt,
'factual_check': factuality_check,
'time_taken': total_time}
except Exception as e:
log.error(e, exc_info=True)
raise