File size: 1,808 Bytes
53f8a32
 
f4fe081
da8a172
d487976
 
 
da8a172
dc422ae
d487976
 
 
e71462a
d487976
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
53f8a32
 
 
 
d487976
 
 
 
 
 
 
 
53f8a32
d487976
53f8a32
4263bcd
53f8a32
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
import logging
from typing import cast
from threading import Lock
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch

from conversation import get_default_conv_template
import gradio as gr

# Registry of chat backends, keyed by a short backend name. Each entry bundles
# the tokenizer, the causal-LM weights and a conversation object created from
# the "minichat" template. Everything is loaded eagerly at module import time,
# so importing this module downloads/loads the model weights.
_MINICHAT_3B = "GeneZC/MiniChat-3B"

talkers = {
    "m3b": dict(
        # The slow (non-fast) tokenizer is requested explicitly.
        tokenizer=AutoTokenizer.from_pretrained(_MINICHAT_3B, use_fast=False),
        # Let accelerate place the weights ("auto") and keep peak host RAM low
        # while loading.
        model=AutoModelForCausalLM.from_pretrained(
            _MINICHAT_3B,
            device_map="auto",
            low_cpu_mem_usage=True,
        ),
        conv=get_default_conv_template("minichat"),
    ),
}

def m3b_talk(text):
    """Send one user message to MiniChat-3B and return the model's reply.

    The message is appended to the module-level ``talkers["m3b"]["conv"]``
    conversation, so successive calls form one multi-turn dialogue. The
    generated reply is written back into that conversation so the next
    prompt includes it. The history is never trimmed, so very long sessions
    may eventually exceed the model's context window.

    Args:
        text: The user's message.

    Returns:
        The decoded reply, with special tokens removed and surrounding
        whitespace stripped.

    Raises:
        Any exception raised by tokenization or generation is re-raised
        after the messages added for this turn have been removed.
    """
    talker = talkers["m3b"]
    tokenizer = talker["tokenizer"]
    model = talker["model"]
    conv = talker["conv"]

    # Remember where this turn starts so a failed generation can be undone,
    # instead of leaving a dangling user message / empty reply slot behind.
    # NOTE(review): assumes conv.messages is a list of [role, message] lists,
    # as in FastChat's Conversation; confirm against conversation.py.
    history_len = len(conv.messages)
    conv.append_message(conv.roles[0], text)
    conv.append_message(conv.roles[1], None)  # open slot for the reply
    try:
        # Tokenize the full templated prompt (history + this turn), not the
        # raw text. generate() needs token ids, and the prompt's token length
        # is what must be sliced off the output below.
        prompt = conv.get_prompt()
        input_ids = tokenizer([prompt]).input_ids
        output_ids = model.generate(
            # device_map="auto" may place the model off-CPU; inputs must follow.
            torch.as_tensor(input_ids).to(model.device),
            do_sample=True,
            temperature=0.2,
            max_new_tokens=1024,
        )
        # generate() returns prompt + continuation; keep only the new tokens.
        new_tokens = output_ids[0][len(input_ids[0]):]
        response = tokenizer.decode(new_tokens, skip_special_tokens=True).strip()
    except Exception:
        del conv.messages[history_len:]
        raise

    # Fill the assistant slot so the next turn's prompt contains this reply.
    conv.messages[-1][1] = response
    return response

def main():
    """Configure logging, build the MiniChat-3B chat UI and serve it.

    The UI is a single input textbox plus an output textbox and a "Send"
    button. Clicking the button (or calling the ``talk_m3b`` API endpoint)
    runs ``m3b_talk``. The Gradio queue is limited to one concurrent worker.
    """
    logging.basicConfig(level=logging.INFO)

    with gr.Blocks() as demo:
        # Header banner.
        with gr.Row(variant="panel"):
            gr.Markdown("## Talk to MiniChat-3B\n\nTalk to MiniChat-3B.")

        # Message input on the left; reply box and send button on the right.
        with gr.Row(variant="panel"):
            with gr.Column(variant="panel"):
                message_box = gr.Textbox(label="Message", placeholder="Type something here...")
            with gr.Column(variant="panel"):
                reply_box = gr.Textbox()
                send_button = gr.Button("Send")

        send_button.click(
            fn=m3b_talk,
            inputs=message_box,
            outputs=reply_box,
            api_name="talk_m3b",
        )

    queued_app = demo.queue(concurrency_count=1)
    queued_app.launch()


if __name__ == "__main__":
    # Start the web server only when executed as a script.
    main()