File size: 1,934 Bytes
b7a466c
 
 
 
 
5db63ca
b7a466c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4922c3e
b7a466c
 
 
 
 
 
19a4e81
63716c9
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
import os
from openai import OpenAI
import gradio as gr


# Single OpenAI client reused by every request in this module; the API key is
# read from the OPENAI_API_KEY environment variable (None when it is unset).
client = OpenAI(
  api_key=os.environ.get("OPENAI_API_KEY"),
)

class Conversation:
  """A chat session against the OpenAI chat-completions API with rolling history.

  The system prompt is always kept as the first message; after each answer,
  the oldest user/assistant rounds are dropped so that at most
  ``num_of_round`` rounds remain.
  """

  def __init__(self, prompt, num_of_round):
    """Start a conversation.

    Args:
      prompt: System prompt placed as the first message of every request.
      num_of_round: Maximum number of user/assistant rounds kept in history.
    """
    self.prompt = prompt
    self.num_of_round = num_of_round
    self.messages = [{"role": "system", "content": self.prompt}]

  def ask(self, question):
    """Send ``question`` together with the retained history; return the reply.

    Args:
      question: The user's message text.

    Returns:
      The assistant's reply text. If the API call fails, the exception is
      printed and returned (not raised), and the unanswered question is
      removed from history so later requests stay well-formed.
    """
    self.messages.append({"role": "user", "content": question})
    try:
      response = client.chat.completions.create(
        model="gpt-4o-mini",
        messages=self.messages,
        # temperature=0.5,
        max_tokens=2048,
        # top_p=1,
      )
    except Exception as e:
      # Roll back the question just added: it never received an answer, and
      # leaving it would produce two consecutive user turns next time.
      self.messages.pop()
      print(e)
      return e

    message = response.choices[0].message.content
    self.messages.append({"role": "assistant", "content": message})
    self._trim_history()
    return message

  def _trim_history(self):
    """Drop the oldest user/assistant rounds beyond ``num_of_round``."""
    # One system message plus two messages per retained round. Clamp at 0 so
    # a negative setting cannot make the loop below spin forever.
    limit = max(self.num_of_round, 0) * 2 + 1
    while len(self.messages) > limit:
      # Index 0 is the system prompt; [1:3] is the oldest question/answer pair.
      del self.messages[1:3]

# System prompt (Chinese, sent verbatim to the model). It casts the bot as
# "Seven Ultraman", a companion for children aged three to seven who answers
# their everyday and learning questions in a friendly, chatty way, with two
# rules: (1) replies must be in Chinese; (2) replies stay within 100 characters.
prompt = """你叫赛文奥特曼,工作是陪伴三岁到七岁的儿童成长,以朋友聊天的方式解答他们在生活和学习中遇到的各种困惑和问题。你的回答需要满足以下要求:
1. 你的回答必须是中文
2. 回答限制在100个字以内"""
# Module-level Conversation built from the prompt; 100 is its num_of_round
# (the round limit the class compares history length against).
conv = Conversation(prompt, 100)

def answer(question, history=None):
  """Handle one chat turn submitted from the Gradio textbox.

  Args:
    question: Text the user just submitted.
    history: This session's flat transcript ``[q1, a1, q2, a2, ...]`` held in
      ``gr.State``. It is extended in place and also returned.

  Returns:
    A tuple ``(pairs, history)``: ``pairs`` is the list of
    ``(question, reply)`` tuples rendered by the Chatbot, and ``history`` is
    the updated transcript to store back into the state.
  """
  # A None default avoids one list being shared by calls that omit history.
  if history is None:
    history = []

  # Build the model context from this session's own transcript. The global
  # ``conv`` is a single object shared by every visitor of the publicly
  # shared app, so asking through it would mix different users' chats.
  session = Conversation(conv.prompt, conv.num_of_round)
  past_rounds = list(zip(history[::2], history[1::2]))
  keep = max(conv.num_of_round, 0)
  # Replay only the most recent rounds so the request size stays bounded
  # (``[-0:]`` would select everything, hence the explicit guard).
  for past_q, past_a in (past_rounds[-keep:] if keep else []):
    session.messages.append({"role": "user", "content": past_q})
    session.messages.append({"role": "assistant", "content": past_a})

  reply = session.ask(question)
  # ask() returns the exception object on failure, and the reply content may
  # possibly be None; the Chatbot and later API requests need plain text.
  if reply is None:
    reply = ""
  elif not isinstance(reply, str):
    reply = str(reply)

  # Record the round only after ask() returns so a raise leaves no dangling
  # question in the session history.
  history.extend([question, reply])
  return list(zip(history[::2], history[1::2])), history

# Gradio UI: a chat transcript above a single-line input box.
with gr.Blocks(
  css="#chatbot{height:300px} .overflow-y-auto{height:500px}",
) as demo:
  # Transcript widget; it displays the (user, bot) pairs returned by answer().
  chatbot = gr.Chatbot(label="儿童陪伴机器人", elem_id="chatbot")
  # Per-session flat history list passed to and returned from answer().
  state = gr.State([])

  with gr.Row():
    txt = gr.Textbox(placeholder="Enter text and press enter", show_label=False)
    # Pressing Enter sends (text, history) to answer(); its two outputs
    # refresh the transcript and the stored history.
    txt.submit(fn=answer, inputs=[txt, state], outputs=[chatbot, state])

# NOTE(review): share=True publishes a public URL with no authentication, so
# anyone holding the link spends this OpenAI key. Consider passing
# auth=(user, password) read from the environment, not hard-coded in source.
demo.launch(share=True)