import gradio as gr
import os
import sys
import json
import requests
import random
API_URL = ""
LA_SERVICE_URL_v1 = "http://bore.testsprep.online:8082/v1/chat/completions"
LA_SERVICE_URL_v2 = "http://bore.testsprep.online:8083/v1/chat/completions"
LA_SERVICE_URL_v3 = "http://bore.testsprep.online:8084/v1/chat/completions"
LA_SERVICE_URL_v4 = "http://bore.testsprep.online:8085/v1/chat/completions"
MODEL2SERVICE = {
'LA-llama-3.1-7b-16k-sft-awq': LA_SERVICE_URL_v1,
'LA-storm-llama-3.1-7b-16k-sft-awq': LA_SERVICE_URL_v2,
'LA-cohere-aya-expanse-8b-16k-sft-awq': LA_SERVICE_URL_v3,
'LA-qwen2.5-7b-16k-sft-awq': LA_SERVICE_URL_v4,
}
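# Each dropdown choice maps to its own OpenAI-style /v1/chat/completions
# endpoint; predict() looks the URL up in this dict when a request is sent.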
DISABLED = os.getenv("DISABLED") == 'True'
BEARER_TOKEN = "Prep@123"
NUM_THREADS = 16
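# NUM_THREADS feeds the queue's default_concurrency_limit at launch time;
# API_URL, DISABLED, and the random import are currently unused.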
SYSTEM_PROMPT = """Bạn là Trợ lý gia sư AI dạy ngôn ngữ Tiếng Anh, tên là Teacher Bee AI. Bạn được xây dựng bởi Prep Education để hướng dẫn học viên làm bài tập trên nền tảng Prepedu.com.
Bạn là một trợ lý thân thiện, tính cách tốt bụng và supportive. Giả sử bạn đang hướng dẫn, giải thích và trả lời câu hỏi cho một đứa trẻ 12 tuổi hoặc ở trình độ ngôn ngữ không cao hơn trình độ của người học."""
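# SYSTEM_PROMPT pre-fills the "System Prompt" textbox in the UI below; users can
# edit it per session, and predict() only prepends it when it is non-empty.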
def exception_handler(exception_type, exception, traceback):
    # Print a one-line summary instead of a full traceback.
    print("%s: %s" % (exception_type.__name__, exception))

sys.excepthook = exception_handler
sys.tracebacklimit = 0
def predict(model_selector, system_prompt, inputs, top_p, temperature, max_tokens, chat_counter, chatbot, history, request: gr.Request):
    # Start with the system prompt (if any), then replay the stored history,
    # which alternates user/assistant turns.
    messages = [{"role": "system", "content": system_prompt}] if system_prompt else []
    headers = {
        "accept": "application/json",
        "Authorization": f"Bearer {BEARER_TOKEN}",
        "Content-Type": "application/json",
    }
    print("\n\n")
    print("=" * 100)
    print(f"chat_counter: {chat_counter}")
    print(f"history: {history}")
    if chat_counter != 0:
        for i, data in enumerate(history):
            role = 'user' if i % 2 == 0 else 'assistant'
            messages.append({"role": role, "content": data})
    messages.append({"role": "user", "content": inputs})
    print(f"messages: {messages}")
    # Both branches of the original first-turn/later-turn split built this same
    # payload, so it is constructed once here.
    payload = {
        "model": "LA-SFT",
        "messages": messages,
        "do_sample": True,
        "temperature": temperature,
        "top_p": top_p,
        "max_tokens": max_tokens,
        "n": 1,
        "stream": True,
        "presence_penalty": 0,
        "frequency_penalty": 0,
    }
    chat_counter += 1
    history.append(inputs)
    token_counter = 0
    partial_words = ""
    counter = 0
    response = None  # defined up front so the final yield works even if the request itself fails
    try:
        print(f"\n>>> Payload: {payload}")
        # Call the API with stream=True
        response = requests.post(MODEL2SERVICE[model_selector], headers=headers, json=payload, stream=True)
        for chunk in response.iter_lines():
            # Skip the first streamed line, then parse each "data: {...}" event.
            if counter == 0:
                counter += 1
                continue
            if chunk.decode():
                chunk = chunk.decode()
                if len(chunk) > 12 and "content" in json.loads(chunk[6:])['choices'][0]['delta']:
                    partial_words += json.loads(chunk[6:])['choices'][0]["delta"]["content"]
                    if token_counter == 0:
                        history.append(" " + partial_words)
                    else:
                        history[-1] = partial_words
                    token_counter += 1
                    # Re-pair history into (user, assistant) tuples for the Chatbot widget
                    # and keep the textbox/button disabled while streaming.
                    yield [(history[i], history[i + 1]) for i in range(0, len(history) - 1, 2)], history, chat_counter, response, gr.update(interactive=False), gr.update(interactive=False)
    except Exception as e:
        print(f'error found: {e}')
    yield [(history[i], history[i + 1]) for i in range(0, len(history) - 1, 2)], history, chat_counter, response, gr.update(interactive=True), gr.update(interactive=True)
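# Each streamed line is assumed to follow the OpenAI-style server-sent-event
# format, e.g.:
#   data: {"choices": [{"delta": {"content": "Hel"}, "index": 0}], ...}
# which is why predict() strips the 6-character "data: " prefix before
# json.loads, and why a terminal "data: [DONE]" line (if the backend sends one)
# is filtered out by the len(chunk) > 12 check.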
def reset_textbox():
    return gr.update(value='', interactive=False), gr.update(interactive=False)
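# reset_textbox clears and disables the textbox and button as soon as a message
# is submitted; the final yield in predict() re-enables both when streaming ends.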
title = """<h1 align="center">Learning Assistant In-house Model</h1>"""
theme = gr.themes.Default(primary_hue="green")
with gr.Blocks(
css="""#col_container { margin-left: auto; margin-right: auto;} #chatbot {height: 520px; overflow: auto;}""",
theme=theme) as demo:
gr.HTML(title)
with gr.Column(elem_id="col_container", visible=True) as main_block:
model_selector = gr.Dropdown(choices=list(MODEL2SERVICE.keys()), label="Select Model", value=list(MODEL2SERVICE.keys())[0])
chatbot = gr.Chatbot(elem_id='chatbot')
inputs = gr.Textbox(placeholder="Hi there!", label="Type an input and press Enter")
state = gr.State([])
with gr.Row():
with gr.Column(scale=7):
b1 = gr.Button(visible=True)
with gr.Column(scale=3):
server_status_code = gr.Textbox(label="Status code from PREP server")
system_prompt = gr.Textbox(placeholder="Enter system prompt here", label="System Prompt", value=SYSTEM_PROMPT)
with gr.Accordion("Parameters", open=False):
top_p = gr.Slider(minimum=0, maximum=1.0, value=0.9, step=0.05, interactive=True,
label="Top-p (nucleus sampling)")
temperature = gr.Slider(minimum=0, maximum=5.0, value=0.1, step=0.1, interactive=True, label="Temperature")
max_tokens = gr.Slider(minimum=0, maximum=16_000, value=4096, step=0.1, interactive=True, label="Max tokens")
chat_counter = gr.Number(value=0, visible=False, precision=0)
inputs.submit(reset_textbox, [], [inputs, b1], queue=False)
inputs.submit(predict, [model_selector, system_prompt, inputs, top_p, temperature, max_tokens, chat_counter, chatbot, state],
[chatbot, state, chat_counter, server_status_code, inputs, b1])
b1.click(reset_textbox, [], [inputs, b1], queue=False)
b1.click(predict, [model_selector, system_prompt, inputs, top_p, temperature, max_tokens, chat_counter, chatbot, state],
[chatbot, state, chat_counter, server_status_code, inputs, b1])
demo.queue(max_size=10, default_concurrency_limit=NUM_THREADS, api_open=False).launch(share=False)
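# The backends can also be queried directly to debug streaming outside the UI,
# e.g. (endpoint behavior assumed to match the payload predict() sends):
#   curl -N http://bore.testsprep.online:8082/v1/chat/completions \
#     -H "Authorization: Bearer Prep@123" \
#     -H "Content-Type: application/json" \
#     -d '{"model": "LA-SFT", "messages": [{"role": "user", "content": "Hi"}], "stream": true}'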