File size: 2,413 Bytes
f745223
504b6c8
e8fb838
1ffd977
 
13a089e
e8fb838
 
f745223
cbcb343
f745223
 
 
d8a82cd
52c453e
f745223
13a089e
a9db698
13a089e
19cbba1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
13a089e
 
 
 
 
 
 
 
 
a9db698
 
13a089e
1ffd977
 
a9db698
f57923a
a9db698
 
 
1ffd977
19cbba1
 
 
cf7aa4d
9b0bdb7
a9db698
9b0bdb7
a9db698
9b0bdb7
 
19cbba1
 
 
 
 
 
 
 
1ffd977
2334dc1
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
import os
import time
import torch
import gradio as gr

from strings import TITLE, ABSTRACT 
from gen import get_pretrained_models, get_output, setup_model_parallel

# Configure a single-process "distributed" run so the model-parallel init in
# gen.setup_model_parallel succeeds without a real launcher (torchrun etc.).
os.environ["RANK"] = "0"
os.environ["WORLD_SIZE"] = "1"
os.environ["MASTER_ADDR"] = "127.0.0.1"
os.environ["MASTER_PORT"] = "50505"

# Initialize model parallelism, then load the 7B checkpoint and tokenizer.
# NOTE(review): "7B"/"tokenizer" look like relative paths — presumably they
# must exist next to this script; confirm against gen.get_pretrained_models.
local_rank, world_size = setup_model_parallel()
generator = get_pretrained_models("7B", "tokenizer", local_rank, world_size)

# Module-level conversation state shared by all requests:
# `history` holds role/content dicts; `simple_history` holds (user, bot)
# pairs in the shape gr.Chatbot expects.
history = []
simple_history = []

def chat(user_input, top_p, temperature, max_gen_len):
    """Generate a bot reply for `user_input` and stream it word by word.

    Appends the exchange to the module-level `history` (role/content dicts)
    and `simple_history` ((user, bot) pairs), then yields `simple_history`
    after each word so the Gradio Chatbot renders a typing effect.

    Args:
        user_input: The user's prompt text.
        top_p: Nucleus-sampling probability mass.
        temperature: Sampling temperature.
        max_gen_len: Maximum number of tokens to generate.

    Yields:
        The updated `simple_history` list after each word is appended.
    """
    bot_response = get_output(
        generator=generator,
        prompt=user_input,
        max_gen_len=max_gen_len,
        temperature=temperature,
        top_p=top_p)

    # The model echoes the prompt; strip it from the front of the output.
    bot_response = bot_response[0][len(user_input):]

    # Trim the trailing (likely incomplete) sentence. str.rfind never raises;
    # it returns -1 when "." is absent — the old try/except sliced [:-1] in
    # that case and silently dropped the last character. Keep the text whole
    # when there is no period to cut at.
    last_period = bot_response.rfind(".")
    if last_period != -1:
        bot_response = bot_response[:last_period]

    history.append({
        "role": "user",
        "content": user_input
    })
    history.append({
        "role": "system",
        "content": bot_response
    })

    simple_history.append((user_input, None))

    response = ""
    for word in bot_response.split(" "):
        time.sleep(0.1)  # throttle so the streaming effect is visible
        response += word + " "
        simple_history[-1] = (user_input, response)
        yield simple_history

def reset_textbox():
    """Return a Gradio update that empties the prompt textbox."""
    cleared = gr.update(value='')
    return cleared

# Build the UI: a centered column with title/abstract, the chat transcript,
# a prompt box, and a collapsed accordion of sampling parameters.
with gr.Blocks(css = """#col_container {width: 95%; margin-left: auto; margin-right: auto;}
                #chatbot {height: 400px; overflow: auto;}""") as demo:

    with gr.Column(elem_id='col_container'):
        gr.Markdown(f"## {TITLE}\n\n\n\n{ABSTRACT}")
        chatbot = gr.Chatbot(elem_id='chatbot')
        textbox = gr.Textbox(placeholder="Enter a prompt")

        with gr.Accordion("Parameters", open=False):
            # Fix label typo ("Genenration") and the odd negative-zero
            # minimums (-0 == 0; written as 0.0 for clarity).
            max_gen_len = gr.Slider(minimum=20, maximum=512, value=256, step=1, interactive=True, label="Max Generation Length",)
            top_p = gr.Slider(minimum=0.0, maximum=1.0, value=1.0, step=0.05, interactive=True, label="Top-p (nucleus sampling)",)
            temperature = gr.Slider(minimum=0.0, maximum=5.0, value=1.0, step=0.1, interactive=True, label="Temperature",)

    # Submitting the textbox streams `chat`'s yields into the chatbot,
    # then clears the textbox.
    textbox.submit(chat, [textbox, top_p, temperature, max_gen_len], chatbot)
    textbox.submit(reset_textbox, [], [textbox])

# Disable the direct API endpoints; queue is required for generator callbacks.
demo.queue(api_open=False).launch()