# Hugging Face Space app (page status at capture time: Sleeping)
# type: ignore -- ignores linting import issues when using multiple virtual environments
import json
import logging

import pandas as pd
import plotly.io as pio  # used below to reset the Plotly theme after the streamlit import
import streamlit as st
import streamlit.components.v1 as components
from deeploy import Client
from shap import TreeExplainer
# Put Plotly back on its stock template now that streamlit has been imported.
pio.templates.default = "plotly"
logging.basicConfig(level=logging.INFO)

# Page layout and heading (kept ahead of every other Streamlit call).
st.set_page_config(layout="wide")
st.title("Your title")

# Widen the sidebar so the 250px-wide Deeploy logo fits.
_SIDEBAR_CSS = """
<style>
    section[data-testid="stSidebar"] {
        width: 300px !important;
    }
</style>
"""
st.markdown(_SIDEBAR_CSS, unsafe_allow_html=True)
def get_model_url():
    """Ask for the Deeploy model URL and extract the workspace and deployment IDs.

    Renders a text area prefilled with the demo deployment URL. The entered URL
    is expected to contain ``.../workspaces/<workspace_id>/deployments/<deployment_id>``.

    Returns:
        tuple[str, str, str]: ``(model_url, workspace_id, deployment_id)``.
        ``model_url`` is the text exactly as entered; an ID that cannot be
        found in it is returned as an empty string.
    """
    model_url = st.text_area(
        "Model URL (without the /explain endpoint, default is the demo deployment)",
        "https://api.app.deeploy.ml/workspaces/708b5808-27af-461a-8ee5-80add68384c7/deployments/9155091a-0abb-45b3-8b3b-24ac33fa556b/",
        height=125,
    )
    # Locate each ID by the path segment that precedes it instead of by a fixed
    # index, so URLs without a scheme, with extra path segments, or with a
    # trailing newline/whitespace still yield the right IDs.
    segments = [part for part in model_url.strip().split("/") if part]

    def _segment_after(marker):
        """Return the path segment that follows ``marker``, or "" if absent."""
        try:
            return segments[segments.index(marker) + 1]
        except (ValueError, IndexError):
            return ""

    return model_url, _segment_after("workspaces"), _segment_after("deployments")
def ChangeButtonColour(widget_label, font_color, background_color="transparent"):
    """Recolour every button in the page whose text equals ``widget_label``.

    Injects a zero-size HTML component whose script walks all ``<button>``
    elements of the parent document and restyles the matching ones. Call it
    after the button has been defined.

    Args:
        widget_label: Exact button text to match (compared with ``innerText``).
        font_color: CSS colour applied to the button text.
        background_color: CSS ``background`` value; defaults to ``"transparent"``.
    """
    # json.dumps produces properly escaped JavaScript string literals, so labels
    # containing quotes or backslashes no longer break the script; "</" is
    # escaped as well so a value cannot terminate the <script> element early.
    label, color, background = (
        json.dumps(value).replace("</", "<\\/")
        for value in (widget_label, font_color, background_color)
    )
    htmlstr = f"""
<script>
var elements = window.parent.document.querySelectorAll('button');
for (var i = 0; i < elements.length; ++i) {{
    if (elements[i].innerText == {label}) {{
        elements[i].style.color = {color};
        elements[i].style.background = {background};
    }}
}}
</script>
"""
    components.html(htmlstr, height=0, width=0)
def predict():
    """Request a prediction for the current inputs and store it in session state.

    Used as the ``on_click`` callback of the "Predict" button. Reads the
    module-level ``client``, ``request_body`` and ``deployment_id``. On success
    the response is stored in ``st.session_state.exp``, the evaluation flag is
    reset and the input expander is collapsed. On failure the error is logged
    with its traceback, shown to the user, and the session state is left as is.
    """
    with st.spinner("Loading prediction and explanation..."):
        try:
            # NOTE(review): this calls the predict endpoint; if explanations are
            # needed (the spinner text and the `exp` name suggest so), switch to
            # the client's explain call — confirm against the Deeploy client API.
            exp = client.predict(
                request_body=request_body, deployment_id=deployment_id
            )
        except Exception:
            # UI callback boundary: report the failure instead of crashing the app.
            logging.exception("Failed to get prediction")
            st.error(
                "Failed to get prediction. "
                + "Check whether you are using the right model URL and token for predictions. "
                + "Contact Deeploy if the problem persists."
            )
            return
    st.session_state.exp = exp
    st.session_state.evaluation_submitted = False
    hide_expander()
def hide_expander():
    """Mark the "Input values for prediction" expander as collapsed.

    The expander reads ``expander_toggle`` as its ``expanded`` state when the
    script runs.
    """
    st.session_state["expander_toggle"] = False
def show_expander():
    """Mark the "Input values for prediction" expander as expanded.

    The expander reads ``expander_toggle`` as its ``expanded`` state when the
    script runs.
    """
    st.session_state["expander_toggle"] = True
def submit_and_clear(evaluation: str):
    """Send the user's evaluation of the latest prediction to Deeploy.

    Used as the ``on_click`` callback of the (currently commented-out)
    "Submit evaluation" button. Reads the prediction response from
    ``st.session_state.exp`` and completes the payload prepared in
    ``st.session_state.evaluation_input``. On success the stored prediction is
    cleared, the evaluation is marked as submitted and the input expander is
    reopened. On failure the error is logged and shown to the user.

    Args:
        evaluation: ``"yes"`` when the user agrees with the prediction; any
            other value is treated as disagreement.
    """
    try:
        exp = st.session_state.exp
        # NOTE(review): assumes the prediction response is dict-like with these
        # keys — confirm against the Deeploy client response format.
        predictions = exp["predictions"]
        request_log_id = exp["requestLogId"]
        prediction_log_id = exp["predictionLogIds"][0]

        evaluation_input = st.session_state.evaluation_input
        if evaluation == "yes":
            evaluation_input["result"] = 0  # Agree with the prediction
        else:
            evaluation_input["result"] = 1  # Disagree with the prediction
            # Keep a desired output the UI already collected; otherwise fall back
            # to negating the first prediction (only meaningful for a binary label).
            evaluation_input.setdefault("value", {"predictions": [not predictions[0]]})

        client.evaluate(
            deployment_id, request_log_id, prediction_log_id, evaluation_input
        )
    except Exception:
        # UI callback boundary: report the failure instead of crashing the app.
        logging.exception("Failed to submit feedback")
        st.error(
            "Failed to submit feedback. "
            + "Check whether you are using the right model URL and token for evaluations. "
            + "Contact Deeploy if the problem persists."
        )
        return
    st.session_state.evaluation_submitted = True
    st.session_state.exp = None
    show_expander()
# Seed session-state defaults on the first run; later reruns keep existing values.
_SESSION_DEFAULTS = {
    "expander_toggle": True,  # input expander starts expanded
    "exp": None,  # latest prediction response (None until Predict succeeds)
    "evaluation_submitted": False,
}
for _state_key, _state_default in _SESSION_DEFAULTS.items():
    if _state_key not in st.session_state:
        st.session_state[_state_key] = _state_default
# Sidebar: configure the connection to Deeploy.
with st.sidebar:
    st.image("deeploy_logo_wide.png", width=250)
    # Ask for host, model URL and token.
    host = st.text_input("Host (Changing is optional)", "app.deeploy.ml")
    model_url, workspace_id, deployment_id = get_model_url()
    # The token is a secret: mask it instead of displaying it in plain text.
    deployment_token = st.text_input(
        "Deeploy Model Token", "my-secret-token", type="password"
    )
    if deployment_token == "my-secret-token":
        st.warning("Please enter Deeploy API token.")
    # In case you need to debug the workspace and deployment ID:
    # st.write("Values below are for debug only:")
    # st.write("Workspace ID: ", workspace_id)
    # st.write("Deployment ID: ", deployment_id)

client_options = {
    "host": host,
    "deployment_token": deployment_token,
    "workspace_id": workspace_id,
}
client = Client(**client_options)
# To inspect the session state while debugging, uncomment:
# with st.expander("Debug session state", expanded=False):
#     st.write(st.session_state)

# Feature inputs for the IRIS model: sepal measurements left, petal measurements right.
with st.expander("Input values for prediction", expanded=st.session_state.expander_toggle):
    st.write("Please input the values for the model.")
    sepal_col, petal_col = st.columns(2)
    with sepal_col:
        sep_len = st.number_input("Sepal length", value=1.0, step=0.1, key="Sepal length")
        sep_wid = st.number_input("Sepal width", value=1.0, step=0.1, key="Sepal width")
    with petal_col:
        pet_len = st.number_input("Petal length", value=1.0, step=0.1, key="Petal length")
        pet_wid = st.number_input("Petal width", value=1.0, step=0.1, key="Petal width")

# Payload sent by predict(): a single instance with the features in the
# order sepal length, sepal width, petal length, petal width.
request_body = {"instances": [[sep_len, sep_wid, pet_len, pet_wid]]}
# Predict (via the predict() callback) and show the latest response, if any.
predict_button = st.button("Predict", on_click=predict)
if st.session_state.exp is not None:
    st.write(st.session_state.exp)
    # Evaluation flow, disabled until the TODO below is filled in. Before enabling:
    # - the response lives in st.session_state.exp (there is no module-level `exp`);
    # - set `desired_output` to the output the user expects when they disagree;
    # - NOTE(review): submit_and_clear() sets st.session_state.exp to None, so the
    #   success branch at the end is unreachable while nested under the
    #   `exp is not None` check — consider showing that message outside it.
    # exp = st.session_state.exp
    # predictions = exp["predictions"]
    # request_log_id = exp["requestLogId"]
    # prediction_log_id = exp["predictionLogIds"][0]
    # # exp_df = pd.DataFrame(
    # #     [exp["explanations"][0]["shap_values"]], columns=exp["featureLabels"]
    # # )
    # st.write("Predictions:", predictions)
    #
    # # Evaluation
    # if st.session_state.evaluation_submitted is False:
    #     evaluation = st.radio("Do you agree with the prediction?", ("yes", "no"))
    #     if evaluation == "no":
    #         desired_output = None  # TODO: set the output the user expects
    #         st.session_state.evaluation_input = {
    #             "result": 1,
    #             "value": {"predictions": [desired_output]},
    #         }
    #     else:
    #         st.session_state.evaluation_input = {"result": 0}
    #     submit_button = st.button("Submit evaluation", on_click=submit_and_clear, args=(evaluation,))
    # else:
    #     st.success("Evaluation submitted successfully.")