|
import os |
|
import gradio as gr |
|
from transformers import AutoTokenizer, AutoModelForCausalLM |
|
|
|
# Hugging Face model id. Can be overridden through the MODEL_NAME environment
# variable (e.g. to point at a smaller chat checkpoint) without editing the
# file; the default is the original hard-coded model.
model_name = os.environ.get("MODEL_NAME", "meta-llama/Llama-2-70b-chat-hf")

# Loaded once at import time; Conversation.ask uses these module-level objects.
# NOTE(review): from_pretrained is called with default dtype/device arguments.
# A 70B checkpoint loaded this way needs a very large amount of memory, so
# confirm the deployment target before relying on these defaults.
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name)
|
|
|
class Conversation:
    """Multi-turn chat session backed by the module-level ``tokenizer``/``model``.

    History is a list of ``{"role": ..., "content": ...}`` dicts. The first
    entry is always the system prompt. After each answer, the oldest
    user/assistant pairs are dropped so that at most ``round`` recent
    exchanges follow the system prompt.

    Args:
        prompt: System prompt placed at the start of every model input.
        round: Number of most recent question/answer rounds kept as context.
    """

    def __init__(self, prompt, round):
        self.prompt = prompt
        self.round = round
        self.messages = [{"role": "system", "content": self.prompt}]

    def ask(self, question):
        """Send ``question`` to the model and return its answer as a string.

        On success the question and answer are appended to the history and the
        history is trimmed to the configured number of rounds. On failure the
        error is printed, the history is left unchanged, and the error text is
        returned as a string so the chat UI can still display it.
        """
        # Build the candidate history without mutating self.messages, so a
        # failed generation does not leave a dangling user message behind.
        pending = self.messages + [{"role": "user", "content": question}]
        try:
            # End with the assistant cue so the model continues with the
            # answer (an EOS token here would signal the text is finished).
            input_text = self._build_message(pending) + "Assistant :"
            inputs = tokenizer(input_text, return_tensors="pt").to(model.device)
            output = model.generate(
                **inputs,
                # Budget for the answer only; max_length would also count the
                # (growing) prompt and eventually leave no room to answer.
                max_new_tokens=200,
                # temperature only takes effect when sampling is enabled.
                do_sample=True,
                temperature=0.5,
                pad_token_id=tokenizer.eos_token_id,
            )
            prompt_len = inputs["input_ids"].shape[-1]
            message = tokenizer.decode(output[0, prompt_len:], skip_special_tokens=True)
            # The plain-text transcript format has no explicit turn stop, so the
            # model may go on to write the next "User :" turn; keep only the reply.
            message = message.split("\nUser :", 1)[0].strip()
        except Exception as e:
            print(e)
            return str(e)

        self.messages = pending + [{"role": "assistant", "content": message}]
        self._trim_history()
        return message

    def _trim_history(self):
        """Drop the oldest user/assistant pairs, always keeping the system prompt."""
        # max() guards against a negative round, which would otherwise never
        # satisfy the loop condition and spin forever.
        limit = 2 * max(self.round, 0) + 1
        while len(self.messages) > limit:
            del self.messages[1:3]

    def _build_message(self, messages):
        """Render chat messages as a plain-text transcript for the model.

        System messages are emitted verbatim (they carry the instructions).
        User and assistant turns are prefixed with ``"User : "`` /
        ``"Assistant : "``. Every rendered message is followed by a blank line,
        and messages with any other role are skipped.
        """
        prefixes = {"system": "", "user": "User : ", "assistant": "Assistant : "}
        parts = []
        for message in messages:
            prefix = prefixes.get(message["role"])
            if prefix is None:
                continue
            parts.append(prefix + message["content"] + "\n\n")
        return "".join(parts)
|
|
|
# System prompt (Chinese). English translation:
#   "You are an expert in big data and AI; answer questions about big data and
#    AI in Chinese. Your answers must satisfy the following requirements:
#    1. Your answer must be in Chinese
#    2. Keep the answer within 200 characters
#    3. Refuse to answer questions that violate social ethics or the law"
prompt = """你是一个大数据和AI领域的专家,用中文回答大数据和AI的相关问题。你的回答需要满足以下要求:

1. 你的回答必须是中文

2. 回答限制在200个字以内

3. 拒绝回答违反社会道德和法律的问题"""

# Single module-level conversation used by answer(); 3 is passed as the
# Conversation's `round` (history window size).
# NOTE(review): because this object is global, every Gradio session shares
# the same model-side history — confirm this is intended.
conv = Conversation(prompt, 3)
|
|
|
def answer(question, history=None):
    """Gradio submit handler: ask the conversation and rebuild the chat view.

    Args:
        question: Text typed by the user.
        history: Flat list alternating user questions and bot replies, as kept
            in the ``gr.State``. ``None`` starts a fresh history.

    Returns:
        ``(responses, history)`` where ``responses`` is the list of
        ``(question, reply)`` tuples shown by the ``gr.Chatbot`` and
        ``history`` is the updated flat list to store back into the state.
    """
    # Never share a default list between calls, and don't mutate the caller's list.
    history = [] if history is None else list(history)
    # NOTE(review): `conv` is one module-level Conversation, so every browser
    # session talks to the same model-side history.
    message = conv.ask(question)
    # str() so a non-string result (e.g. an error object) can still be displayed.
    history += [question, str(message)]
    responses = list(zip(history[::2], history[1::2]))
    print(responses)
    return responses, history
|
|
|
# Chat UI: a chat history pane, a per-session state holding the flat
# question/reply list, and a single-line input box.
with gr.Blocks(css="#chatbot{height:300px} .overflow-y-auto{height:500px}") as rxbot:
    chatbot = gr.Chatbot(elem_id="chatbot")
    state = gr.State(value=[])

    with gr.Row():
        # Placeholder text means "Please enter your question".
        txt = gr.Textbox(
            show_label=False,
            placeholder="请输入你的问题",
        ).style(container=False)

    # On submit: answer(question, history) -> (chat pairs, updated history).
    txt.submit(fn=answer, inputs=[txt, state], outputs=[chatbot, state])

rxbot.launch()