File size: 3,670 Bytes
f1d3421
 
 
 
0318995
 
 
4879b2a
 
 
 
 
 
 
 
 
 
 
 
 
 
1331b8a
4879b2a
1331b8a
4879b2a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6a8f0b2
e40a495
 
 
 
 
 
6a8f0b2
4879b2a
 
 
 
528220d
4879b2a
 
 
 
 
 
 
1331b8a
e40a495
6a8f0b2
 
e40a495
5031463
528220d
6a8f0b2
1331b8a
6a8f0b2
1331b8a
 
 
4879b2a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
# HACK: force a known-good Gradio version at process start (Hugging Face
# Spaces workaround). Prefer pinning gradio==3.50.2 in requirements.txt.
import subprocess
import sys

# Run pip via the current interpreter with an argument list (shell=False):
# unambiguous, injection-safe, and guaranteed to target this environment,
# unlike os.system("pip ...") which depends on whatever `pip` is on PATH.
# check=False mirrors os.system's ignore-the-exit-code behavior.
subprocess.run([sys.executable, "-m", "pip", "uninstall", "-y", "gradio"], check=False)
subprocess.run([sys.executable, "-m", "pip", "install", "gradio==3.50.2"], check=False)

from huggingface_hub import InferenceClient
import gradio as gr

"""
Chat engine.

TODOs:
- Better prompts.
- Output reader / parser.
- Agents for evaluation and task planning / splitting.
    * Haystack for orchestration
- Tools for agents
    * Haystack for orchestration
- 

"""

# Hugging Face model repo id used for all completions.
# TODO (see UI notes below): make this selectable from the dropdown.
selected_model = "mistralai/Mixtral-8x7B-Instruct-v0.1"

# Inference API client bound to the selected model; all text generation
# in query_completion goes through this single module-level client.
client = InferenceClient(selected_model)

def query_submit(user_message, history):
    """Append the user's message to the chat log and clear the input box.

    Returns a pair: the new textbox value ("") and the updated history,
    where the pending assistant reply is left as None to be filled in
    later by the streaming completion handler.
    """
    updated_history = history + [[user_message, None]]
    return "", updated_history
  
def format_prompt(query, history, lookback):
    """Build a Mixtral-instruct prompt from recent chat history and the new query.

    Args:
        query: The user's new message.
        history: Chat log as [user_message, assistant_reply] pairs.
        lookback: Number of most recent exchanges to include as context.

    Returns:
        Prompt string in Mixtral's ``<s>[INST] ... [/INST] ...</s>`` format,
        prefixed with a length-limiting instruction.
    """
    prompt = "Responses should be no more than 100 words long.\n"

    # BUG FIX: guard lookback <= 0 explicitly. history[-0:] slices from
    # index 0, i.e. it returns the WHOLE history instead of none of it.
    recent = history[-lookback:] if lookback > 0 else []

    for previous_query, previous_completion in recent:
        prompt += f"<s>[INST] {previous_query} [/INST] {previous_completion}</s> "

    prompt += f"[INST] {query} [/INST]"

    return prompt

def query_completion(
    query,
    history,
    lookback=3,
    max_new_tokens=256,
):
    """Stream a completion for ``query``, updating the last history entry.

    Args:
        query: The user's new message.
        history: Chat log as [user_message, assistant_reply] pairs; the last
            entry's reply slot (left as None by query_submit) is filled in
            incrementally as tokens arrive.
        lookback: Number of previous exchanges to include as prompt context.
        max_new_tokens: Generation cap forwarded to the Inference API.

    Yields:
        The full history after each streamed token, so the UI can re-render
        the chat log live.
    """
    # Renamed from camelCase (generateKwargs) to match PEP 8 / file style.
    generate_kwargs = dict(
        max_new_tokens=max_new_tokens,
        seed=1337,  # fixed seed for reproducible generations
    )

    formatted_query = format_prompt(query, history, lookback)

    # details=True exposes per-token metadata (response.token below);
    # return_full_text=False keeps the prompt out of the streamed output.
    stream = client.text_generation(
        formatted_query,
        **generate_kwargs,
        stream=True,
        details=True,
        return_full_text=False,
    )

    # Start accumulating the assistant reply into the pending (None) slot.
    history[-1][1] = ""

    for response in stream:
        history[-1][1] += response.token.text
        yield history

        
"""
Chat UI using Gradio Blocks.

Blocks preferred for lower-level "atomic" layout control and state management.

TODOs:
- State management for dynamic components update.
- Add scratchpad readout to the right of the chat log.
    * Placeholder added for now.
- Add functionality to retry button.
    * Placeholder added for now.
- Add dropdown for model selection.
- Add textbox for HF model selection.
    
"""

# Chat UI layout: chat log + scratchpad row, model/file row, query row,
# retry/clear controls, system-prompt accordion, and a footer naming the
# active model. Event wiring appends the user message first (unqueued, so
# it feels instant), then streams the model reply into the chat log.
with gr.Blocks() as chatUI:
    with gr.Row():
        chatOutput = gr.Chatbot(
            bubble_full_width = False,
            scale = 2
        )
        agentWhiteBoard = gr.Markdown(scale = 1)  # placeholder scratchpad pane

    with gr.Row():
        modelSelect = gr.Dropdown(
            label = "Model selection:",
        )
        fileUpload = gr.File(
            height = 50,
        )

    with gr.Row():
        queryInput = gr.Textbox(
            # Fixed user-facing typo: "you question" -> "your question".
            placeholder = "Please enter your question or request here...",
            show_label = False,
            scale = 4,
        )
        submitButton = gr.Button("Submit", scale = 1)

    with gr.Row():
        retry = gr.Button("Retry (null)")  # placeholder: no handler wired yet
        clear = gr.ClearButton([queryInput, chatOutput])

    with gr.Row():
        # Fixed user-facing grammar: "Expand for edit" -> "Expand to edit".
        with gr.Accordion(label = "Expand to edit system prompt:"):
            systemPrompt = gr.Textbox(
                value = "System prompt here (null)",
                show_label = False,
                lines = 4,
                scale = 4,
            )

    with gr.Row():
        # BUG FIX: the original line ended with a trailing comma, which bound
        # `footer` to a 1-tuple instead of the gr.HTML component itself.
        footer = gr.HTML("<div class='footer'>" + selected_model + "</div>")

        # gr.State()

    # Pressing Enter: append the message (queue=False for instant echo),
    # then stream the assistant reply via the generator handler.
    queryInput.submit(
        fn = query_submit,
        inputs = [queryInput, chatOutput],
        outputs = [queryInput, chatOutput],
        queue = False,
    ).then(
        fn = query_completion,
        inputs = [queryInput, chatOutput],
        outputs = [chatOutput],
    )

    # Submit button mirrors the Enter-key behavior exactly.
    submitButton.click(
        fn = query_submit,
        inputs = [queryInput, chatOutput],
        outputs = [queryInput, chatOutput],
        queue = False,
    ).then(
        fn = query_completion,
        inputs = [queryInput, chatOutput],
        outputs = [chatOutput],
    )

# Queuing is required for generator (streaming) event handlers.
chatUI.queue()
chatUI.launch(show_api = False)