File size: 4,492 Bytes
d65e306
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
c0c37aa
d65e306
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
"""Custom actions used within a dashboard."""

import base64
import io
import logging

import black
import pandas as pd
from _utils import check_file_extension
from dash.exceptions import PreventUpdate
from langchain_openai import ChatOpenAI
from plotly import graph_objects as go
from vizro.models.types import capture
from vizro_ai import VizroAI

logger = logging.getLogger(name=__name__)
# TODO: remove manual setting and make centrally controlled
logger.setLevel(level=logging.INFO)

# Maps a vendor name (the `vendor_input` value passed to `get_vizro_ai_plot`) to the
# LangChain chat-model class used to build the LLM.
SUPPORTED_VENDORS = dict(OpenAI=ChatOpenAI)


def get_vizro_ai_plot(user_prompt, df, model, api_key, api_base, vendor_input):  # noqa: PLR0913
    """Build an LLM for the chosen vendor and ask VizroAI to plot ``df``.

    Args:
        user_prompt: Natural-language description of the desired chart.
        df: DataFrame to plot.
        model: Model name handed to the vendor's chat class as ``model_name``.
        api_key: Passed to the vendor class as ``openai_api_key``.
        api_base: Passed to the vendor class as ``openai_api_base``.
        vendor_input: Key into ``SUPPORTED_VENDORS``; an unknown key raises ``KeyError``.

    Returns:
        Whatever ``VizroAI.plot(..., explain=False, return_elements=True)`` returns; the
        caller reads its ``.code`` and ``.figure`` attributes.
    """
    llm_class = SUPPORTED_VENDORS[vendor_input]
    llm = llm_class(model_name=model, openai_api_key=api_key, openai_api_base=api_base)
    return VizroAI(model=llm).plot(df, user_prompt, explain=False, return_elements=True)


@capture("action")
def run_vizro_ai(user_prompt, n_clicks, data, model, api_key, api_base, vendor_input):  # noqa: PLR0913
    """Gets the AI response and adds it to the text window.

    Args:
        user_prompt: Natural-language request forwarded to VizroAI.
        n_clicks: Click count of the triggering control; a falsy value means "not clicked yet".
        data: Upload payload. On a successful upload this is ``{"data": records, "filename": name}``;
            after a failed upload it is ``{"error_message": ...}`` with neither key present.
        model: Model name forwarded to ``get_vizro_ai_plot``.
        api_key: API key forwarded to ``get_vizro_ai_plot``; must be non-empty and not start with a quote.
        api_base: Optional API base forwarded to ``get_vizro_ai_plot``; must not start with a quote.
        vendor_input: Vendor name forwarded to ``get_vizro_ai_plot``.

    Returns:
        A 3-tuple ``(ai_response, figure, snapshot)``: the text shown to the user (a fenced
        python code block on success, an error message otherwise), the Plotly figure (an empty
        ``go.Figure()`` on any error), and a dict with the response, figure JSON, prompt and filename.

    Raises:
        PreventUpdate: When ``n_clicks`` is falsy.
    """

    def create_response(ai_response, figure, user_prompt, filename):
        plotly_fig = figure.to_json()
        return (
            ai_response,
            figure,
            {"ai_response": ai_response, "figure": plotly_fig, "prompt": user_prompt, "filename": filename},
        )

    def error_response(message, filename):
        # Every validation/failure path shows its message next to an empty figure.
        return create_response(message, go.Figure(), user_prompt, filename)

    if not n_clicks:
        raise PreventUpdate

    # A failed upload yields {"error_message": ...} without "data"/"filename"; treat it as
    # "no data" rather than crashing on data["filename"] / data["data"] further down.
    if not data or "data" not in data:
        return error_response("Please upload data to proceed!", None)

    filename = data.get("filename")

    if not api_key:
        return error_response("API key not found. Make sure you enter your API key!", filename)

    if api_key.startswith('"'):
        return error_response("Make sure you enter your API key without quotes!", filename)

    if api_base is not None and api_base.startswith('"'):
        return error_response("Make sure you enter your API base without quotes!", filename)

    try:
        logger.info("Attempting chart code.")
        df = pd.DataFrame(data["data"])
        ai_outputs = get_vizro_ai_plot(
            user_prompt=user_prompt,
            df=df,
            model=model,
            api_key=api_key,
            api_base=api_base,
            vendor_input=vendor_input,
        )
        ai_code = ai_outputs.code
        figure = ai_outputs.figure
        formatted_code = black.format_str(ai_code, mode=black.Mode(line_length=100))

        ai_response = "\n".join(["```python", formatted_code, "```"])
        logger.info("Successful query produced.")
        return create_response(ai_response, figure, user_prompt, filename)

    except Exception as exc:
        # UI boundary: any failure (LLM call, data conversion, formatting) is reported to the
        # user instead of breaking the callback; keep the traceback available at debug level.
        logger.debug(exc, exc_info=True)
        logger.info("Chart creation failed.")
        return error_response(f"Sorry, I can't do that. Following Error occurred: {exc}", filename)


@capture("action")
def data_upload_action(contents, filename):
    """Custom data upload action.

    Decodes an uploaded file and converts it to a list of row records.

    Args:
        contents: Upload payload of the form ``"<header>,<base64 payload>"`` (presumably a
            data URL from a Dash upload component — TODO confirm).
        filename: Name of the uploaded file. It is validated by ``check_file_extension``;
            a ``.csv`` suffix (any case) selects the CSV parser, anything else ``pd.read_excel``.

    Returns:
        ``{"data": records, "filename": filename}`` on success, where ``records`` is
        ``df.to_dict("records")``; otherwise ``{"error_message": message}``.

    Raises:
        PreventUpdate: When ``contents`` is empty.
    """
    if not contents:
        raise PreventUpdate

    if not check_file_extension(filename=filename):
        return {"error_message": "Unsupported file extension. Make sure to upload either csv or an excel file."}

    try:
        # Split on the first comma only: the header precedes it, the base64 payload follows.
        # Done inside the try so a malformed payload produces an error message, not a crash.
        _, content_string = contents.split(",", 1)
        decoded = base64.b64decode(content_string)
        if filename.lower().endswith(".csv"):
            # "utf-8-sig" drops a leading byte-order mark (common in spreadsheet-exported CSVs)
            # that plain "utf-8" would leave glued to the first column name; BOM-less input is
            # decoded identically.
            df = pd.read_csv(io.StringIO(decoded.decode("utf-8-sig")))
        else:
            # Handle Excel file
            df = pd.read_excel(io.BytesIO(decoded))

        data = df.to_dict("records")
        return {"data": data, "filename": filename}

    except Exception as exc:
        # Best-effort upload: report a generic message to the user, keep details for debugging.
        logger.debug(exc, exc_info=True)
        return {"error_message": "There was an error processing this file."}


@capture("action")
def display_filename(data):
    """Turn the upload payload into the text shown under the upload area.

    Args:
        data: Upload payload, either ``{"filename": ...}``-style or ``{"error_message": ...}``.

    Returns:
        ``"Uploaded file name: '<value>'"`` when the payload has a ``"filename"`` key,
        otherwise the error message (or ``None`` when neither key is present).

    Raises:
        PreventUpdate: When no payload exists yet.
    """
    if data is None:
        raise PreventUpdate

    shown = data.get("filename") or data.get("error_message")
    if "filename" not in data:
        return shown
    return f"Uploaded file name: '{shown}'"