# phucpx: add passage templates (commit d20f3c3)
import gradio as gr
import os
import sys
import json
import requests
import random
# Legacy endpoint placeholder — appears unused in this file; verify before removing.
API_URL = ""
# One OpenAI-compatible chat-completions endpoint per fine-tuned model (ports 8082-8085).
LA_SERVICE_URL_v1 = "http://bore.testsprep.online:8082/v1/chat/completions"
LA_SERVICE_URL_v2 = "http://bore.testsprep.online:8083/v1/chat/completions"
LA_SERVICE_URL_v3 = "http://bore.testsprep.online:8084/v1/chat/completions"
LA_SERVICE_URL_v4 = "http://bore.testsprep.online:8085/v1/chat/completions"
# Maps the UI model-selector label to the serving endpoint that hosts it.
MODEL2SERVICE = {
    'LA-llama-3.1-7b-16k-sft-awq': LA_SERVICE_URL_v1,
    'LA-storm-llama-3.1-7b-16k-sft-awq': LA_SERVICE_URL_v2,
    'LA-cohere-aya-expanse-8b-16k-sft-awq': LA_SERVICE_URL_v3,
    'LA-qwen2.5-7b-16k-sft-awq': LA_SERVICE_URL_v4,
}
# True only when the DISABLED env var is exactly the string "True".
# NOTE(review): not referenced elsewhere in this chunk — confirm it is used downstream.
DISABLED = os.getenv("DISABLED") == 'True'
# Static bearer token expected by the LA services.
BEARER_TOKEN = "Prep@123"
# Concurrency limit for the Gradio request queue (see demo.queue at the bottom).
NUM_THREADS = 16
# Default system prompt (Vietnamese): introduces the "Teacher Bee AI" English-tutor
# persona for Prepedu.com learners. Runtime value — do not translate.
SYSTEM_PROMPT = """Bạn là Trợ lý gia sư AI dạy ngôn ngữ Tiếng Anh, tên là Teacher Bee AI. Bạn được xây dựng bởi Prep Education để hướng dẫn học viên làm bài tập trên nền tảng Prepedu.com.
Bạn là một trợ lý thân thiện, tính cách tốt bụng và supportive. Giả sử bạn đang hướng dẫn, giải thích và trả lời câu hỏi cho một đứa trẻ 12 tuổi hoặc ở trình độ ngôn ngữ không cao hơn trình độ của người học."""
def exception_handler(exception_type, exception, traceback):
    """Print a compact one-line error report: "<ExceptionType>: <message>"."""
    print(f"{exception_type.__name__}: {exception}")


# Replace the default traceback dump with the one-liner above and suppress
# traceback frames entirely for any uncaught exception.
sys.excepthook = exception_handler
sys.tracebacklimit = 0
def predict(model_selector, system_prompt, inputs, top_p, temperature, max_tokens, chat_counter, chatbot, history, request: gr.Request):
    """Stream a chat completion from the selected LA service and yield UI updates.

    Args:
        model_selector: Key into MODEL2SERVICE choosing the backend endpoint.
        system_prompt: Optional system message prepended to the conversation.
        inputs: The user's new message.
        top_p / temperature / max_tokens: Sampling parameters forwarded to the service.
        chat_counter: Number of completed turns so far (0 means a fresh chat).
        chatbot: Gradio Chatbot component value (unused here; kept for signature parity).
        history: Flat list of alternating user/assistant turns; mutated in place.
        request: Gradio request object (unused).

    Yields:
        (chatbot_pairs, history, chat_counter, response, inputs_update, button_update)
        once per streamed token, then a final tuple re-enabling the input controls.
    """
    messages = [{"role": "system", "content": system_prompt}] if system_prompt else []
    headers = {
        "accept": "application/json",
        # Use the module-level constant instead of a duplicated hardcoded token.
        "Authorization": f"Bearer {BEARER_TOKEN}",
        "Content-Type": "application/json",
    }
    print("\n\n")
    print("=" * 100)
    print(f"chat_counter: {chat_counter}")
    print(f"history: {history}")

    if chat_counter != 0:
        # Replay prior turns: even indices are user messages, odd are assistant.
        for i, turn in enumerate(history):
            role = 'user' if i % 2 == 0 else 'assistant'
            messages.append({"role": role, "content": turn})
    messages.append({"role": "user", "content": inputs})
    print(f"messages: {messages}")

    # Payload is identical for new and continuing chats — build it once.
    payload = {
        "model": "LA-SFT",
        "messages": messages,
        "do_sample": True,
        "temperature": temperature,
        "top_p": top_p,
        "max_tokens": max_tokens,
        "n": 1,
        "stream": True,
        "presence_penalty": 0,
        "frequency_penalty": 0,
    }

    chat_counter += 1
    history.append(inputs)
    token_counter = 0
    partial_words = ""
    counter = 0
    # Defined before the try so the final yield can reference it even when
    # requests.post itself raises (previously an UnboundLocalError).
    response = None
    try:
        print(f"\n>>> Payload: {payload}")
        # Call the service with streaming enabled (server-sent events).
        response = requests.post(MODEL2SERVICE[model_selector], headers=headers, json=payload, stream=True)
        for chunk in response.iter_lines():
            if counter == 0:
                # Skip the first streamed line — presumably the role-only delta
                # chunk; TODO confirm against the serving backend.
                counter += 1
                continue
            decoded = chunk.decode()
            # Length guard also filters the 12-char "data: [DONE]" terminator.
            if decoded and len(decoded) > 12:
                # Strip the "data: " SSE prefix and parse the JSON once per chunk.
                delta = json.loads(decoded[6:])['choices'][0]['delta']
                if "content" in delta:
                    partial_words += delta["content"]
                    if token_counter == 0:
                        history.append(" " + partial_words)
                    else:
                        history[-1] = partial_words
                    token_counter += 1
                    # Keep the input controls disabled while tokens stream in.
                    yield [(history[i], history[i + 1]) for i in range(0, len(history) - 1, 2)], history, chat_counter, response, gr.update(interactive=False), gr.update(interactive=False)
    except Exception as e:
        # Best-effort streaming: log and fall through to re-enable the UI.
        print(f'error found: {e}')
    # Final yield re-enables the textbox and button regardless of outcome.
    yield [(history[i], history[i + 1]) for i in range(0, len(history) - 1, 2)], history, chat_counter, response, gr.update(interactive=True), gr.update(interactive=True)
def reset_textbox():
    """Clear the input textbox and disable both it and the submit button.

    Returns a pair of Gradio component updates applied to (inputs, b1)
    while a generation request is in flight.
    """
    cleared_box = gr.update(value='', interactive=False)
    disabled_button = gr.update(interactive=False)
    return cleared_box, disabled_button
# --- UI construction -------------------------------------------------------
title = """<h1 align="center">Learning Assistant In-house Model</h1>"""
theme = gr.themes.Default(primary_hue="green")

with gr.Blocks(
        css="""#col_container { margin-left: auto; margin-right: auto;} #chatbot {height: 520px; overflow: auto;}""",
        theme=theme) as demo:
    gr.HTML(title)
    with gr.Column(elem_id="col_container", visible=True) as main_block:
        # Model picker defaults to the first registered LA service.
        model_selector = gr.Dropdown(choices=list(MODEL2SERVICE.keys()), label="Select Model", value=list(MODEL2SERVICE.keys())[0])
        chatbot = gr.Chatbot(elem_id='chatbot')
        inputs = gr.Textbox(placeholder="Hi there!", label="Type an input and press Enter")
        # Flat list of alternating user/assistant turns (see predict()).
        state = gr.State([])
        with gr.Row():
            with gr.Column(scale=7):
                b1 = gr.Button(visible=True)
            with gr.Column(scale=3):
                server_status_code = gr.Textbox(label="Status code from PREP server")
        system_prompt = gr.Textbox(placeholder="Enter system prompt here", label="System Prompt", value=SYSTEM_PROMPT)
        with gr.Accordion("Parameters", open=False):
            top_p = gr.Slider(minimum=0, maximum=1.0, value=0.9, step=0.05, interactive=True,
                              label="Top-p (nucleus sampling)")
            temperature = gr.Slider(minimum=0, maximum=5.0, value=0.1, step=0.1, interactive=True, label="Temperature")
            # Fixed: token counts are integers — step was 0.1, which produced
            # fractional max_tokens values from the slider.
            max_tokens = gr.Slider(minimum=0, maximum=16_000, value=4096, step=1, interactive=True, label="Max tokens")
        # Hidden turn counter; precision=0 keeps it an int.
        chat_counter = gr.Number(value=0, visible=False, precision=0)

    # Enter-in-textbox and the button trigger the same reset + predict pipeline,
    # so the argument lists are declared once.
    predict_inputs = [model_selector, system_prompt, inputs, top_p, temperature, max_tokens, chat_counter, chatbot, state]
    predict_outputs = [chatbot, state, chat_counter, server_status_code, inputs, b1]
    inputs.submit(reset_textbox, [], [inputs, b1], queue=False)
    inputs.submit(predict, predict_inputs, predict_outputs)
    b1.click(reset_textbox, [], [inputs, b1], queue=False)
    b1.click(predict, predict_inputs, predict_outputs)

demo.queue(max_size=10, default_concurrency_limit=NUM_THREADS, api_open=False).launch(share=False)