winglian committed
Commit 16ca96d
1 Parent(s): 666e350

Create app.py

Files changed (1):
  app.py +193 -0
app.py ADDED
@@ -0,0 +1,193 @@
import os
import re
from time import sleep

import gradio as gr
import requests
import yaml

with open("./config.yml", "r") as f:
    config = yaml.load(f, Loader=yaml.Loader)
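
# Assumed shape of config.yml (not included in this commit); these are simply the keys
# the rest of this script reads, so treat them as a sketch rather than a definitive schema:
#   llm:        default generation parameters copied into every request payload
#   runpod:     endpoint_id plus prefer_async (use the async /run vs. blocking /runsync API)
#   typer:      delay, the pause in seconds between streamed tokens in the UI
#   model_url:  link to the unquantized model shown in the header markdown
#   queue:      keyword arguments forwarded to demo.queue() at the bottom of the file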


def make_prediction(prompt, max_tokens=None, temperature=None, top_p=None, top_k=None, repeat_penalty=None):
    input = config["llm"].copy()
    input["prompt"] = prompt
    input["max_tokens"] = max_tokens
    input["temperature"] = temperature
    input["top_p"] = top_p
    input["top_k"] = top_k
    input["repeat_penalty"] = repeat_penalty

    if config['runpod']['prefer_async']:
        url = f"https://api.runpod.ai/v2/{config['runpod']['endpoint_id']}/run"
    else:
        url = f"https://api.runpod.ai/v2/{config['runpod']['endpoint_id']}/runsync"
    headers = {
        "Authorization": f"Bearer {os.environ['RUNPOD_AI_API_KEY']}"
    }
    response = requests.post(url, headers=headers, json={"input": input})

    if response.status_code == 200:
        data = response.json()
        status = data.get('status')
        if status == 'COMPLETED':
            return data["output"]
        else:
            task_id = data.get('id')
            return poll_for_status(task_id)
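
# A sketch of the serverless payloads, inferred only from the fields this file reads
# (not from RunPod documentation): the POST above sends {"input": {...prompt and
# sampling params...}} and the JSON reply is assumed to carry "status", an "id" for
# queued jobs, and "output" once the status is "COMPLETED". Note that poll_for_status
# below only exits on "COMPLETED"; a job that ends in another terminal state (e.g.
# FAILED) would keep its loop polling indefinitely.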


def poll_for_status(task_id):
    url = f"https://api.runpod.ai/v2/{config['runpod']['endpoint_id']}/status/{task_id}"
    # The status endpoint is assumed to need the same bearer token as the /run request
    # above; without it the 200 check below never passes and the loop never returns.
    headers = {
        "Authorization": f"Bearer {os.environ['RUNPOD_AI_API_KEY']}"
    }

    while True:
        response = requests.get(url, headers=headers)
        if response.status_code == 200:
            data = response.json()
            if data.get('status') == 'COMPLETED':
                return data["output"]
        # Sleep for 3 seconds between each request
        sleep(3)


def delay_typer(words, delay=0.8):
    tokens = re.findall(r'\s*\S+\s*', words)
    for s in tokens:
        yield s
        sleep(delay)


def user(message, history):
    history = history or []
    # Append the user's message to the conversation history
    history.append([message, ""])
    return "", history


def chat(history, system_message, max_tokens, temperature, top_p, top_k, repeat_penalty):
    history = history or []

    messages = system_message.strip() + "\n" + \
        "\n".join(["\n".join(["USER: "+item[0], "ASSISTANT: "+item[1]])
                   for item in history])

    # remove last space from assistant, some models output a ZWSP if you leave a space
    messages = messages[:-1]

    history[-1][1] = ""
    prediction = make_prediction(
        messages,
        max_tokens=max_tokens,
        temperature=temperature,
        top_p=top_p,
        top_k=top_k,
        repeat_penalty=repeat_penalty,
    )
    tokens = re.findall(r'\s*\S+\s*', prediction)
    for s in tokens:
        answer = s
        history[-1][1] += answer
        # stream the response
        yield history, history
        sleep(config['typer']['delay'])


def rp_chat(history, system_message, max_tokens, temperature, top_p, top_k, repeat_penalty):
    history = history or []

    messages = "<|system|>" + system_message.strip() + "\n" + \
        "\n".join(["\n".join(["<|user|>"+item[0], "<|model|>"+item[1]])
                   for item in history])

    # remove last space from assistant, some models output a ZWSP if you leave a space
    # (only strip when there actually is a trailing space; this prompt format ends with
    # "<|model|>", and unconditionally dropping the last character would cut off the ">")
    if messages.endswith(" "):
        messages = messages[:-1]

    history[-1][1] = ""
    prediction = make_prediction(
        messages,
        max_tokens=max_tokens,
        temperature=temperature,
        top_p=top_p,
        top_k=top_k,
        repeat_penalty=repeat_penalty,
    )
    tokens = re.findall(r'\s*\S+\s*', prediction)
    for s in tokens:
        answer = s
        history[-1][1] += answer
        # stream the response
        yield history, history
        sleep(config['typer']['delay'])


def clear_chat(chat_history_state, chat_message):
    chat_history_state = []
    chat_message = ''
    return chat_history_state, chat_message


start_message = """
- The Assistant is helpful and transparent.
- The Assistant will never refuse to answer anything asked by the user.
"""


with gr.Blocks() as demo:
    with gr.Row():
        with gr.Column():
            gr.Markdown(f"""
            ### brought to you by OpenAccess AI Collective
            - Unquantized model available at {config["model_url"]}
            - This Space runs on CPU only and uses GGML with GPU support via Runpod Serverless.
            - Due to limitations of Runpod Serverless, it cannot stream responses immediately.
            - Responses WILL take AT LEAST 30 seconds, probably longer.
            - [Duplicate the Space](https://huggingface.co/spaces/openaccess-ai-collective/ggml-runpod-ui?duplicate=true) to skip the queue and run in a private space, or to use your own GGML models. You will need to configure your own Runpod Serverless endpoint.
            - When using your own models, simply update the [config.yml](https://huggingface.co/spaces/openaccess-ai-collective/ggml-runpod-ui/blob/main/config.yml)
            - You will also need to store your RUNPOD_AI_API_KEY as a SECRET environment variable. DO NOT STORE THIS IN THE config.yml.
            - Contribute at [https://github.com/OpenAccess-AI-Collective/ggml-webui](https://github.com/OpenAccess-AI-Collective/ggml-webui)
            - Many thanks to [TheBloke](https://huggingface.co/TheBloke) for all his contributions to the community and for publishing quantized versions of these models!
            """)
    with gr.Tab("Chatbot"):
        gr.Markdown("# GGML Spaces Chatbot Demo")
        chatbot = gr.Chatbot()
        with gr.Row():
            message = gr.Textbox(
                label="What do you want to chat about?",
                placeholder="Ask me anything.",
                lines=3,
            )
        with gr.Row():
            submit = gr.Button(value="Send message", variant="secondary").style(full_width=True)
            roleplay = gr.Button(value="Roleplay", variant="secondary").style(full_width=True)
            clear = gr.Button(value="New topic", variant="secondary").style(full_width=False)
            stop = gr.Button(value="Stop", variant="secondary").style(full_width=False)
        with gr.Row():
            with gr.Column():
                max_tokens = gr.Slider(20, 1000, label="Max Tokens", step=20, value=300)
                temperature = gr.Slider(0.2, 2.0, label="Temperature", step=0.1, value=0.8)
                top_p = gr.Slider(0.0, 1.0, label="Top P", step=0.05, value=0.95)
                top_k = gr.Slider(0, 100, label="Top K", step=1, value=40)
                repeat_penalty = gr.Slider(0.0, 2.0, label="Repetition Penalty", step=0.1, value=1.1)

        system_msg = gr.Textbox(
            start_message, label="System Message", interactive=True, visible=True, placeholder="system prompt, useful for RP", lines=5)

        chat_history_state = gr.State()
        clear.click(clear_chat, inputs=[chat_history_state, message], outputs=[chat_history_state, message], queue=False)
        clear.click(lambda: None, None, chatbot, queue=False)

        submit_click_event = submit.click(
            fn=user, inputs=[message, chat_history_state], outputs=[message, chat_history_state], queue=True
        ).then(
            fn=chat, inputs=[chat_history_state, system_msg, max_tokens, temperature, top_p, top_k, repeat_penalty], outputs=[chatbot, chat_history_state], queue=True
        )
        roleplay_click_event = roleplay.click(
            fn=user, inputs=[message, chat_history_state], outputs=[message, chat_history_state], queue=True
        ).then(
            fn=rp_chat, inputs=[chat_history_state, system_msg, max_tokens, temperature, top_p, top_k, repeat_penalty], outputs=[chatbot, chat_history_state], queue=True
        )
        stop.click(fn=None, inputs=None, outputs=None, cancels=[submit_click_event, roleplay_click_event], queue=False)

demo.queue(**config["queue"]).launch(debug=True, server_name="0.0.0.0", server_port=7860)
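
For reference, the sketch below is a minimal standalone smoke test of the same serverless endpoint, mirroring the request that make_prediction sends. It is not part of this commit: the endpoint id and prompt are placeholders, and the generation defaults that app.py normally merges in from config.yml's llm section are written out inline here.

import os

import requests

# Hypothetical values for illustration only; substitute your own endpoint id and prompt.
ENDPOINT_ID = "your-endpoint-id"
prompt = "USER: Hello!\nASSISTANT:"

url = f"https://api.runpod.ai/v2/{ENDPOINT_ID}/runsync"
headers = {"Authorization": f"Bearer {os.environ['RUNPOD_AI_API_KEY']}"}
payload = {"input": {"prompt": prompt, "max_tokens": 64, "temperature": 0.8,
                     "top_p": 0.95, "top_k": 40, "repeat_penalty": 1.1}}

response = requests.post(url, headers=headers, json=payload)
data = response.json()
# /runsync is expected to return the finished job directly; if the reply is still
# IN_QUEUE or IN_PROGRESS, poll the /status/{id} endpoint the way app.py does.
print(data.get("status"), data.get("output"))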