File size: 3,761 Bytes
fe1089d
2492536
fe1089d
 
 
 
 
 
2492536
fe1089d
 
 
2492536
fe1089d
 
 
 
 
 
 
2492536
 
fe1089d
 
 
 
 
2492536
ba1dc89
2492536
fe1089d
 
2492536
ba1dc89
58a02af
fe1089d
 
 
 
 
 
2492536
fe1089d
 
2492536
c28c597
 
 
 
 
 
 
fe1089d
2492536
fe1089d
 
 
 
 
 
 
 
 
 
f5ebee7
fe1089d
 
 
 
d2116db
fe1089d
 
 
f5ebee7
fe1089d
 
 
 
 
 
 
 
 
2492536
fe1089d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2492536
f5ebee7
d2116db
fe1089d
 
 
 
f5ebee7
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
# controller for the application that calls the model and explanation functions
# returns the updated conversation history and extra elements

# external imports
import gradio as gr

# internal imports
from model import godel
from explanation import interpret_shap as shap_int, visualize as viz


# main interference function that calls the chat functions depending on the selections
# it is called on every chat submit
def interference(
    prompt: str,
    history: list,
    knowledge: str,
    system_prompt: str,
    xai_selection: str,
):
    """Route a chat submission to the vanilla or the explained chat function.

    Args:
        prompt: the user message that was submitted.
        history: conversation history; handed to the chat function, which
            appends the new (message, answer) pair to it.
        knowledge: additional knowledge text passed on to prompt formatting.
        system_prompt: system prompt; a default one is used when it is
            missing, empty or consists only of whitespace.
        xai_selection: selected XAI approach; exactly "SHAP" or "Attention"
            enable the explained chat, any other value uses the vanilla chat.

    Returns:
        A 4-tuple ``(prompt_output, history_output, xai_graphic, xai_markup)``.
        Without a selected XAI approach, ``xai_graphic`` is a disclaimer HTML
        snippet and ``xai_markup`` is ``[("", "")]``.
    """
    # if no proper system prompt is given, use a default one
    # (covers None, '' and any whitespace-only input, not just '' and ' ')
    if not system_prompt or not system_prompt.strip():
        system_prompt = """
            You are a helpful, respectful and honest assistant.
            Always answer as helpfully as possible, while being safe.
        """

    # if no XAI approach is selected call the vanilla chat function
    if xai_selection not in ("SHAP", "Attention"):
        prompt_output, history_output = vanilla_chat(
            model=godel,
            message=prompt,
            history=history,
            system_prompt=system_prompt,
            knowledge=knowledge,
        )
        # set XAI outputs to disclaimer html/none
        xai_graphic, xai_markup = (
            """
            <div style="text-align: center"><h4>Without Selected XAI Approach,
            no graphic will be displayed</h4></div>
            """,
            [("", "")],
        )
        return prompt_output, history_output, xai_graphic, xai_markup

    # the guard above leaves exactly two labels, so a fallback branch for
    # unknown selections could never be reached and is not needed here
    xai = shap_int if xai_selection == "SHAP" else viz

    # call the explained chat function with the model instance
    prompt_output, history_output, xai_graphic, xai_markup = explained_chat(
        model=godel,
        xai=xai,
        message=prompt,
        history=history,
        system_prompt=system_prompt,
        knowledge=knowledge,
    )

    # return the outputs
    return prompt_output, history_output, xai_graphic, xai_markup


# simple chat function that calls the model
# formats prompts, calls for an answer and returns updated conversation history
def vanilla_chat(
    model, message: str, history: list, system_prompt: str, knowledge: str = ""
):
    """Answer a message with the model and record the exchange in the history.

    The prompt is built with ``model.format_prompt`` and answered with
    ``model.respond``. The given ``history`` list is extended in place with
    the ``(message, answer)`` pair.

    Returns:
        A tuple of an empty string and the updated history list.
    """
    # build the prompt and let the model respond to it in one go
    reply = model.respond(
        model.format_prompt(message, history, system_prompt, knowledge)
    )

    # record the new exchange at the end of the conversation
    exchange = (message, reply)
    history.append(exchange)

    # empty string first, then the updated history
    return "", history


def explained_chat(
    model, xai, message: str, history: list, system_prompt: str, knowledge: str = ""
):
    """Answer a message through an XAI module and record the exchange.

    The prompt is built with ``model.format_prompt`` and passed, together
    with the model, to ``xai.chat_explained``, which yields the answer plus
    two explanation outputs. The given ``history`` list is extended in place
    with the ``(message, answer)`` pair.

    Returns:
        A 4-tuple of an empty string, the updated history list and the two
        explanation outputs (graphic and markup) as returned by the XAI module.
    """
    # build the prompt with the model's own formatting
    formatted = model.format_prompt(message, history, system_prompt, knowledge)

    # the XAI module produces the answer together with its explanation outputs
    explained = xai.chat_explained(model, formatted)
    reply, graphic, markup = explained

    # record the new exchange at the end of the conversation
    history.append((message, reply))

    # empty string, updated history, then the explanation outputs
    return "", history, graphic, markup