Spaces:
Running
Running
File size: 4,492 Bytes
d65e306 c0c37aa d65e306 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 |
"""Custom actions used within a dashboard."""
import base64
import io
import logging
import black
import pandas as pd
from _utils import check_file_extension
from dash.exceptions import PreventUpdate
from langchain_openai import ChatOpenAI
from plotly import graph_objects as go
from vizro.models.types import capture
from vizro_ai import VizroAI
# Module-level logger named after this module.
# NOTE: with the level pinned to INFO below, the logger.debug(...) calls in this
# module are not emitted by this logger.
logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO) # TODO: remove manual setting and make centrally controlled
# Maps the vendor name received as `vendor_input` to the chat-model class that
# get_vizro_ai_plot instantiates for it.
SUPPORTED_VENDORS = {"OpenAI": ChatOpenAI}
def get_vizro_ai_plot(user_prompt, df, model, api_key, api_base, vendor_input):  # noqa: PLR0913
    """Build a chat model for the chosen vendor and ask VizroAI to plot `df`.

    Args:
        user_prompt: Instruction text passed to ``VizroAI.plot``.
        df: Data passed to ``VizroAI.plot`` as its first argument.
        model: Forwarded to the vendor class as ``model_name``.
        api_key: Forwarded to the vendor class as ``openai_api_key``.
        api_base: Forwarded to the vendor class as ``openai_api_base``.
        vendor_input: Key into ``SUPPORTED_VENDORS`` selecting the vendor class.

    Returns:
        Whatever ``VizroAI.plot`` returns when called with ``explain=False`` and
        ``return_elements=True``.

    Raises:
        KeyError: If ``vendor_input`` is not a key of ``SUPPORTED_VENDORS``.
    """
    llm_class = SUPPORTED_VENDORS[vendor_input]
    chat_model = llm_class(model_name=model, openai_api_key=api_key, openai_api_base=api_base)
    return VizroAI(model=chat_model).plot(df, user_prompt, explain=False, return_elements=True)
@capture("action")
def run_vizro_ai(user_prompt, n_clicks, data, model, api_key, api_base, vendor_input):  # noqa: PLR0913
    """Get the AI response for the prompt and return it together with the figure.

    Args:
        user_prompt: Instruction text forwarded to ``get_vizro_ai_plot``.
        n_clicks: Click counter of the triggering control; a falsy value aborts the update.
        data: Uploaded-data store. This function reads ``data["data"]`` (used to build a
            DataFrame) and ``data["filename"]``. Presumably this is the dict returned by
            ``data_upload_action``, which on failure holds only ``"error_message"`` —
            that shape is handled as "no data uploaded".
        model: Model name forwarded to ``get_vizro_ai_plot``.
        api_key: API key forwarded to ``get_vizro_ai_plot``; must be non-empty and unquoted.
        api_base: Optional API base forwarded to ``get_vizro_ai_plot``; must be unquoted.
        vendor_input: Vendor name forwarded to ``get_vizro_ai_plot``.

    Returns:
        A 3-tuple ``(ai_response, figure, record)`` where ``record`` is a dict with the
        keys ``"ai_response"``, ``"figure"`` (the figure serialised with ``to_json()``),
        ``"prompt"`` and ``"filename"``. On any validation or generation failure the
        figure is an empty ``go.Figure()`` and ``ai_response`` carries the message.

    Raises:
        PreventUpdate: If ``n_clicks`` is falsy.
    """

    def create_response(ai_response, figure, user_prompt, filename):
        plotly_fig = figure.to_json()
        return (
            ai_response,
            figure,
            {"ai_response": ai_response, "figure": plotly_fig, "prompt": user_prompt, "filename": filename},
        )

    def error_response(message, filename):
        # Every failure path returns the message alongside an empty figure.
        return create_response(message, go.Figure(), user_prompt, filename)

    if not n_clicks:
        raise PreventUpdate

    # A store that only carries "error_message" is truthy but has neither "data" nor
    # "filename"; indexing it would raise KeyError, so treat it like a missing upload.
    if not data or "data" not in data:
        return error_response("Please upload data to proceed!", None)

    # .get keeps the failure paths (including the except handler) from raising KeyError.
    filename = data.get("filename")
    quote_chars = ('"', "'")

    if not api_key:
        return error_response("API key not found. Make sure you enter your API key!", filename)

    if api_key.startswith(quote_chars):
        return error_response("Make sure you enter your API key without quotes!", filename)

    if api_base is not None and api_base.startswith(quote_chars):
        return error_response("Make sure you enter your API base without quotes!", filename)

    try:
        logger.info("Attempting chart code.")
        df = pd.DataFrame(data["data"])
        ai_outputs = get_vizro_ai_plot(
            user_prompt=user_prompt,
            df=df,
            model=model,
            api_key=api_key,
            api_base=api_base,
            vendor_input=vendor_input,
        )
        ai_code = ai_outputs.code
        figure = ai_outputs.figure
        formatted_code = black.format_str(ai_code, mode=black.Mode(line_length=100))

        # Wrap the generated code in a fenced block for display.
        ai_response = "\n".join(["```python", formatted_code, "```"])
        logger.info("Successful query produced.")
        return create_response(ai_response, figure, user_prompt, filename)

    except Exception as exc:
        # Callback boundary: report the failure to the user instead of propagating it.
        logger.debug(exc)
        logger.info("Chart creation failed.")
        return error_response(f"Sorry, I can't do that. Following Error occurred: {exc}", filename)
@capture("action")
def data_upload_action(contents, filename):
    """Decode an uploaded CSV or Excel file into records.

    Args:
        contents: Upload payload of the form ``"<header>,<base64 data>"``; only the part
            after the first comma is decoded.
        filename: Name of the uploaded file. It is validated with ``check_file_extension``
            and a ``.csv`` suffix (any letter case) selects the CSV reader; anything else
            that passes validation is read as Excel.

    Returns:
        ``{"data": <list of row dicts>, "filename": filename}`` on success, or
        ``{"error_message": <text>}`` when the extension is rejected or parsing fails.

    Raises:
        PreventUpdate: If ``contents`` is falsy.
    """
    if not contents:
        raise PreventUpdate

    if not check_file_extension(filename=filename):
        return {"error_message": "Unsupported file extension.. Make sure to upload either csv or an excel file."}

    try:
        # Split on the first comma only, inside the try: a payload without the expected
        # "<header>,<data>" shape now yields the error message instead of an uncaught
        # ValueError from tuple unpacking.
        _, content_string = contents.split(",", 1)
        decoded = base64.b64decode(content_string)

        if filename.lower().endswith(".csv"):
            # Handle CSV file
            df = pd.read_csv(io.StringIO(decoded.decode("utf-8")))
        else:
            # Handle Excel file
            df = pd.read_excel(io.BytesIO(decoded))

        return {"data": df.to_dict("records"), "filename": filename}

    except Exception as e:
        # Callback boundary: any decode/parse failure is reported as a generic message.
        logger.debug(e)
        return {"error_message": "There was an error processing this file."}
@capture("action")
def display_filename(data):
    """Return the text to show for the current upload state.

    Args:
        data: Mapping that may contain ``"filename"`` and/or ``"error_message"``.

    Returns:
        ``"Uploaded file name: '<name>'"`` when ``data`` has a ``"filename"`` key
        (falling back to the error message inside the quotes if the filename is
        falsy); otherwise the value of ``"error_message"``.

    Raises:
        PreventUpdate: If ``data`` is ``None``.
    """
    if data is None:
        raise PreventUpdate

    # No filename key: show the error message as-is, without the prefix.
    if "filename" not in data:
        return data.get("error_message")

    shown = data.get("filename") or data.get("error_message")
    return f"Uploaded file name: '{shown}'"
|