# AML / utils.py
# Last change by adollbo, commit 2187b60: "added new functions, fixed data processing".
import json
import logging
from random import randrange, uniform

import numpy as np
import pandas as pd
import streamlit as st
import streamlit.components.v1 as components
# Display names for the columns of a raw transaction row, in column order.
# Index 1 ("Transaction type") is the categorical field that `transformation`
# removes and re-appends as a one-hot encoding.
COL_NAMES = [
    "Time step",
    "Transaction type",
    "Amount transferred",
    "Sender's initial balance",
    "Sender's new balance",
    "Recipient's initial balance",
    "Recipient's new balance",
    "Sender exactly credited",
    "Receiver exactly credited",
    "Amount > 450 000",
    "Frequent receiver",
    "Merchant receiver",
    "Sender ID",
    "Receiver ID",
]

# Explanation labels keyed by feature position in the transformed layout:
# the raw columns without "Transaction type" (0-12), followed by the five
# one-hot transaction-type columns (13-17) in the same order as CATEGORIES.
feature_texts = {
    0: "Time step: ",
    1: "Amount transferred: ",
    2: "Initial balance of sender: ",
    3: "New balance of sender: ",
    4: "Initial balance of recipient: ",
    5: "New balance of recipient: ",
    6: "Sender's balance was exactly credited: ",
    7: "Receiver's balance was exactly credited: ",
    8: "Transaction over 450.000: ",
    9: "Frequent receiver of transactions: ",
    10: "Receiver is merchant: ",
    11: "Sender ID: ",
    12: "Receiver ID: ",
    13: "Transaction type is Cash out",
    14: "Transaction type is Transfer",
    15: "Transaction type is Payment",
    16: "Transaction type is Cash in",
    17: "Transaction type is Debit",
}

# Transaction-type labels; their order defines the order of the one-hot columns.
CATEGORIES = np.array(["CASH_OUT", "TRANSFER", "PAYMENT", "CASH_IN", "DEBIT"])
def transformation(input, categories):
    """Replace the categorical value at position 1 with a one-hot encoding.

    The element at index 1 of ``input`` is removed, and a one-hot vector with
    one slot per entry of ``categories`` is appended to the end. A value that
    matches no category yields all zeros.

    Args:
        input: Sequence of feature values whose index 1 holds the category.
            (The name shadows the builtin but is kept because callers pass it
            by keyword.)
        categories: Category labels (array-like); their order defines the
            order of the one-hot slots.

    Returns:
        A new list of the remaining values followed by the one-hot ints.
        ``input`` itself is not modified.

    Raises:
        IndexError: if ``input`` has fewer than two elements.
    """
    # Work on a copy so the caller's list is left untouched.
    new_x = list(input)
    cat = new_x.pop(1)
    # Size the encoding from the categories instead of assuming five.
    one_hot = np.zeros(len(categories), dtype=int)
    one_hot[np.asarray(categories) == cat] = 1
    new_x.extend(one_hot.tolist())
    return new_x
def get_request_body(datapoint):
    """Build a prediction request body from the first row of ``datapoint``.

    Args:
        datapoint: DataFrame whose first row holds the feature values.

    Returns:
        ``{'instances': [row_values]}``, with every NumPy scalar converted to
        the equivalent Python builtin so the body can be JSON-encoded.
    """
    row = datapoint.iloc[0].tolist()
    # A row taken from a mixed-dtype frame can hold NumPy scalars of any kind
    # (np.int64, np.bool_, np.float32, ...). None of np.bool_/np.intXX are
    # JSON-serializable; .item() converts each one to its Python equivalent.
    instances = [x.item() if isinstance(x, np.generic) else x for x in row]
    return {'instances': [instances]}
def get_explainability_texts(shap_values, feature_texts):
    """Rank the positively contributing features and look up their labels.

    Args:
        shap_values: Per-feature SHAP values; position i corresponds to key i
            of ``feature_texts``.
        feature_texts: Mapping from feature index to label.

    Returns:
        ``(texts, indices)``. ``indices`` lists every feature index whose SHAP
        value is > 0, strongest first. ``texts`` holds the labels of ranks 3-7
        of that list: the two strongest are skipped and at most five are kept.
    """
    # (index, value) pairs; the index doubles as the key into feature_texts.
    contributions = [(pos, shap) for pos, shap in enumerate(shap_values) if shap > 0]
    # sorted() is stable even with reverse=True, so ties keep ascending index order.
    ranked = sorted(contributions, key=lambda pair: abs(pair[1]), reverse=True)
    ranked_indices = [pos for pos, _ in ranked]
    labels = [feature_texts[pos] for pos in ranked_indices]
    return labels[2:7], ranked_indices
def get_explainability_values(pos_indices, datapoint):
    """Return display values for the ranked features of a datapoint.

    The first row of ``datapoint`` has its floats rounded to two decimals and
    is passed through ``transformation``, so its positions line up with
    ``feature_texts``. Boolean-valued features are rendered as
    ``"True"``/``"False"``.

    Args:
        pos_indices: Feature indices ranked strongest first, as returned by
            ``get_explainability_texts``.
        datapoint: DataFrame whose first row is the transaction being explained.

    Returns:
        Values for ranks 3-7 of ``pos_indices``. This is the same slice that
        ``get_explainability_texts`` applies to its labels, so the two lists
        pair up.
    """
    data = datapoint.iloc[0].tolist()
    rounded_data = [round(value, 2) if isinstance(value, float) else value for value in data]
    transformed_data = transformation(input=rounded_data, categories=CATEGORIES)
    # Boolean features in the transformed layout:
    # - 6-10: the five flags, from "Sender's balance was exactly credited"
    #   through "Receiver is merchant";
    # - 13-17: the one-hot transaction-type columns.
    bool_indices = set(range(6, 11)) | set(range(13, 18))
    vals = [
        str(bool(transformed_data[idx])) if idx in bool_indices else transformed_data[idx]
        for idx in pos_indices
    ]
    return vals[2:7]
def get_fake_certainty():
    """Return a made-up model certainty formatted as a percentage string.

    The value is drawn uniformly from [0.75, 0.99] and shown with two
    decimals (e.g. ``"87.42%"``). It is not derived from any model output.
    """
    return f"{uniform(0.75, 0.99):.2%}"
def get_random_suspicious_transaction(data):
    """Pick one transaction flagged as fraud, uniformly at random.

    Args:
        data: DataFrame with an ``isFraud`` column; rows where it equals 1
            are the candidates.

    Returns:
        A one-row DataFrame holding the chosen transaction, without the
        ``isFraud`` column.

    Raises:
        ValueError: if ``data`` has no rows with ``isFraud == 1``.
    """
    suspicious_data = data[data["isFraud"] == 1]
    if suspicious_data.empty:
        raise ValueError("No transactions with isFraud == 1 to choose from.")
    # randrange(n) yields 0..n-1. Slicing [nr:nr + 1] by position selects
    # exactly that row, so every candidate is reachable and the result is
    # never empty.
    random_nr = randrange(len(suspicious_data))
    return suspicious_data.iloc[random_nr:random_nr + 1].drop("isFraud", axis=1)
def send_evaluation(client, deployment_id, request_log_id, prediction_log_id, evaluation_input):
    """Send evaluation to Deeploy.

    Calls ``client.evaluate`` with the given IDs and evaluation payload while
    a Streamlit spinner is shown.

    Args:
        client: Object exposing ``evaluate(deployment_id, request_log_id,
            prediction_log_id, evaluation_input)``.
        deployment_id: ID of the deployment the prediction belongs to.
        request_log_id: ID of the logged request.
        prediction_log_id: ID of the logged prediction being evaluated.
        evaluation_input: Evaluation payload forwarded unchanged.

    Returns:
        True if the call succeeded. False if it raised; in that case the error
        is logged and shown in the UI.
    """
    try:
        with st.spinner("Submitting response..."):
            client.evaluate(deployment_id, request_log_id, prediction_log_id, evaluation_input)
    except Exception as e:
        # Broad catch on purpose: this is the UI boundary, and any failure is
        # reported to the user instead of crashing the app.
        logging.exception("Failed to submit evaluation to Deeploy")
        st.error(
            "Failed to submit feedback. "
            + "Check whether you are using the right model URL and Token. "
            + "Contact Deeploy if the problem persists."
        )
        st.write(f"Error message: {e}")
        return False
    return True
def get_model_url():
    """Ask for the Deeploy model URL and extract the workspace and deployment IDs.

    The IDs are taken from segments 4 and 6 of the URL split on ``"/"``, which
    matches ``https://<host>/workspaces/<workspace_id>/deployments/<deployment_id>/``.
    If the URL has too few segments, both IDs are returned as empty strings.

    Returns:
        ``(model_url, workspace_id, deployment_id)``.
    """
    model_url = st.text_area(
        "Model URL (default is the demo deployment)",
        "https://api.app.deeploy.ml/workspaces/708b5808-27af-461a-8ee5-80add68384c7/deployments/ac56dbdf-ba04-462f-aa70-5a0d18698e42/",
        height=125,
    )
    segments = model_url.split("/")
    if len(segments) > 6:
        return model_url, segments[4], segments[6]
    # Too short to contain both IDs.
    return model_url, "", ""
def get_comment_explanation(certainty, explainability_texts, explainability_values):
    """Build the free-text comment summarising why a transaction was flagged.

    Args:
        certainty: Model certainty, inserted verbatim (e.g. ``"87.42%"``).
        explainability_texts: Feature labels such as ``"Time step: "``. Colons
            and surrounding whitespace are removed.
        explainability_values: Values paired by position with the labels.

    Returns:
        A string made of:
        - ``"Model certainty is <certainty>"``;
        - a blank line;
        - the header ``"Important suspicious features: "``;
        - one ``"<label> is <value>"`` line per value.

    Raises:
        IndexError: if there are more values than labels.
    """
    # Labels end in ": ". Strip the leftover space as well, so lines read
    # "Time step is 5" rather than "Time step  is 5".
    cleaned = [text.replace(':', '').strip() for text in explainability_texts]
    lines = ['Important suspicious features: ']
    lines.extend(f'{cleaned[i]} is {value}' for i, value in enumerate(explainability_values))
    result = '\n'.join(lines)
    return f"Model certainty is {certainty}\n\n{result}"
def create_data_input_table(datapoint, col_names):
    """Render the first row of ``datapoint`` as a "Feature name" / "Value" table.

    Values at positions 7-11 are shown as booleans (with COL_NAMES these are
    the five flag columns). Other float values are rounded to two decimals.
    """
    st.subheader("Flagged Transaction:")
    raw_values = datapoint.iloc[0].tolist()
    display_values = []
    for position, value in enumerate(raw_values):
        if 7 <= position < 12:
            display_values.append(bool(value))
        elif isinstance(value, float):
            display_values.append(round(value, 2))
        else:
            display_values.append(value)
    table = pd.DataFrame({"Feature name": col_names, "Value": display_values})
    # Height grows with the row count so the whole table is visible without scrolling.
    st.dataframe(table, hide_index=True, width=450, height=35 * len(table) + 38)
def create_table(texts, values, title):
    """Render ``title`` as a level-4 heading, followed by a two-column table.

    Args:
        texts: Entries for the "Feature Explanation" column.
        values: Entries for the "Value" column, paired by position with ``texts``.
        title: Heading text shown above the table.
    """
    table = pd.DataFrame({"Feature Explanation": texts, "Value": values})
    st.markdown(f"#### {title}")
    st.dataframe(table, hide_index=True, width=450)
def ChangeButtonColour(widget_label, font_color, background_color='transparent'):
    """Recolour every button in the parent page whose visible text equals ``widget_label``.

    Injects a zero-size HTML component. Its script iterates over all
    ``button`` elements of ``window.parent.document`` and sets their text
    colour and background.

    Args:
        widget_label: Exact button text to match (compared to ``innerText``).
        font_color: CSS colour applied to the button text.
        background_color: CSS value applied to ``style.background``.
    """
    def _js_string(value):
        # json.dumps produces a valid, escaped JS string literal, so quotes
        # (e.g. "Don't know") cannot break the script. Escaping "</" keeps a
        # value from closing the <script> tag early.
        return json.dumps(str(value)).replace("</", "<\\/")

    htmlstr = f"""
    <script>
        var elements = window.parent.document.querySelectorAll('button');
        for (var i = 0; i < elements.length; ++i) {{
            if (elements[i].innerText == {_js_string(widget_label)}) {{
                elements[i].style.color = {_js_string(font_color)};
                elements[i].style.background = {_js_string(background_color)};
            }}
        }}
    </script>
    """
    components.html(htmlstr, height=0, width=0)