MaziyarPanahi commited on
Commit
9e608cc
·
unverified ·
1 Parent(s): c130254

Add application file

Browse files
Files changed (1) hide show
  1. app.py +170 -0
app.py ADDED
@@ -0,0 +1,170 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import time
3
+ import requests
4
+ import json
5
+ import os
6
+
7
# Model identifier; NOTE(review): not referenced anywhere else in this file —
# presumably informational, confirm before removing.
MODEL = "gpt-4-0125-preview"
# Base URL and bearer token for the OpenAI-compatible endpoint, read from env.
API_URL = os.getenv("API_URL")
API_KEY = os.getenv("API_KEY")

# Chat-completions endpoint derived from the base URL.
url = f"{API_URL}/v1/chat/completions"

# The headers for the HTTP request
headers = {
    "accept": "application/json",
    "Content-Type": "application/json",
    "Authorization": f"Bearer {API_KEY}",
}
19
+
20
+
21
def is_valid_json(data):
    """Attempt to parse *data* as JSON.

    Returns a (success, payload) pair: (True, parsed_object) when the text
    is valid JSON, otherwise (False, error_message_string).
    """
    try:
        parsed = json.loads(data)
    except ValueError as exc:  # json.JSONDecodeError subclasses ValueError
        return False, str(exc)
    return True, parsed
29
+
30
+
31
# UI layout: chat display, message box, clear button, and generation controls.
with gr.Blocks() as demo:

    chatbot = gr.Chatbot()
    msg = gr.Textbox()
    clear = gr.Button("Clear")
    with gr.Row():

        with gr.Column(scale=4):
            # Define inputs for additional parameters
            system_prompt_input = gr.Textbox(
                label="System Prompt",
                placeholder="Type system prompt here...",
                value="You are a helpful assistant.",
            )
            temperature_input = gr.Slider(
                label="Temperature", minimum=0.0, maximum=1.0, value=0.9, step=0.01
            )
            max_new_tokens_input = gr.Slider(
                label="Max New Tokens", minimum=0, maximum=1024, value=256, step=1
            )
            top_p_input = gr.Slider(
                label="Top P", minimum=0.0, maximum=1.0, value=0.9, step=0.01
            )
            repetition_penalty_input = gr.Slider(
                label="Repetition Penalty",
                minimum=1.0,
                maximum=2.0,
                value=1.2,
                step=0.01,
            )
        with gr.Column(scale=1):
            # App title shown beside the parameter controls.
            markup = gr.Markdown("## Mistral 7B Instruct v0.2 GGUF")
63
+
64
+ def update_globals(
65
+ system_prompt, temperature, max_new_tokens, top_p, repetition_penalty
66
+ ):
67
+ global global_system_prompt, global_temperature, global_max_new_tokens, global_top_p, global_repetition_penalty
68
+ global_system_prompt = system_prompt
69
+ global_temperature = temperature
70
+ global_max_new_tokens = max_new_tokens
71
+ global_top_p = top_p
72
+ global_repetition_penalty = repetition_penalty
73
+
74
def user(user_message, history):
    """Record a submitted message in the chat history.

    Returns ("", new_history): the empty string clears the textbox, and the
    new history gains a [message, None] pair whose None slot is the
    yet-to-arrive assistant reply. The input history is not mutated.
    """
    updated = list(history)
    updated.append([user_message, None])
    return "", updated
78
+
79
def bot(
    history, system_prompt, temperature, max_new_tokens, top_p, repetition_penalty
):
    """Stream the model's reply to the latest user message into the history.

    Converts the Gradio history into OpenAI-style chat messages, POSTs them
    to the chat-completions endpoint with streaming enabled, and yields the
    growing history so the Chatbot updates token by token.
    """
    print(f"History in bot: {history}")
    print(f"System Prompt: {system_prompt}")
    print(f"Temperature: {temperature}")
    print(f"Max New Tokens: {max_new_tokens}")
    print(f"Top P: {top_p}")
    print(f"Repetition Penalty: {repetition_penalty}")

    # Convert [[user, assistant], ...] into chat messages. Fix: the original
    # sent only the user turns, silently dropping every prior assistant reply
    # from the model's context.
    history_messages = []
    for user_turn, assistant_turn in history:
        if user_turn:
            history_messages.append({"role": "user", "content": user_turn})
        if assistant_turn:
            history_messages.append({"role": "assistant", "content": assistant_turn})
    # Placeholder the streaming reply accumulates into.
    history[-1][1] = ""

    # Fix: the system prompt must be a message dict; the original did
    # `system_prompt + history_messages` (str + list), a TypeError at runtime.
    history_messages = [{"role": "system", "content": system_prompt}] + history_messages
    print(history_messages)

    # NOTE(review): the original also sent response_format={"type":
    # "json_object"}, which forces raw-JSON replies in a chat UI; dropped —
    # confirm JSON-only output was not intentional.
    data = {
        "messages": history_messages,
        "stream": True,
        "temperature": temperature,  # fix: key was misspelled "temprature"
        "top_k": 50,
        "top_p": top_p,  # fix: was hard-coded to 0.95, ignoring the slider
        "seed": 42,
        "repeat_penalty": repetition_penalty,
        "chat_format": "mistral-instruct",
        "max_tokens": max_new_tokens,
    }

    # POST and consume the response as server-sent events.
    response = requests.post(
        url, headers=headers, data=json.dumps(data), stream=True
    )
    for line in response.iter_lines():
        # Filter out keep-alive new lines
        if not line:
            continue
        payload = line.decode("utf-8")
        # Fix: lstrip("data: ") strips the CHARACTER SET {d,a,t,:,space,e}
        # and can eat the start of the JSON body; remove the exact SSE prefix.
        if payload.startswith("data: "):
            payload = payload[len("data: "):]
        ok, parsed = is_valid_json(payload)
        if not ok:
            # Non-JSON lines (e.g. the final "[DONE]" sentinel) are skipped.
            continue
        delta_content = (
            parsed.get("choices", [{}])[0].get("delta", {}).get("content", "")
        )
        if delta_content:  # Ensure there's content to append
            history[-1][1] += delta_content
            time.sleep(0.05)  # small delay smooths the UI stream
            yield history
150
    # Wire the textbox: on submit, `user` echoes the message into the chat
    # (queue disabled so it feels instant), then `bot` streams the reply
    # into the same Chatbot component.
    msg.submit(
        user, [msg, chatbot], [msg, chatbot], queue=False, concurrency_limit=5
    ).then(
        bot,
        inputs=[
            chatbot,
            system_prompt_input,
            temperature_input,
            max_new_tokens_input,
            top_p_input,
            repetition_penalty_input,
        ],
        outputs=chatbot,
    )

    # Clear button: the lambda returns None, which empties the Chatbot.
    clear.click(lambda: None, None, chatbot, queue=False)
166
+
167
+
168
# Enable queuing so the streaming generator callback works for concurrent users.
demo.queue()
if __name__ == "__main__":
    demo.launch(show_api=False, share=False)