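# Minimal Gradio Space: single-turn chat with GeneZC/MiniChat-3B.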
import logging
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch
from conversation import get_default_conv_template
import gradio as gr
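
# Load the tokenizer, model, and conversation template once at startup so
# every request reuses the same weights instead of reloading them.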
talkers = {
    "m3b": {
        "tokenizer": AutoTokenizer.from_pretrained("GeneZC/MiniChat-3B", use_fast=False),
        "model": AutoModelForCausalLM.from_pretrained("GeneZC/MiniChat-3B", device_map="auto", low_cpu_mem_usage=True),
        "conv": get_default_conv_template("minichat"),
    }
}
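
# Single-turn chat handler: wrap the message in the MiniChat conversation
# template, sample a completion, and return only the newly generated text.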
def m3b_talk(text):
    # Start from a fresh copy of the template so turns from earlier requests
    # do not accumulate (assumes the FastChat-style Conversation API that
    # MiniChat's conversation.py provides).
    m3bconv = talkers["m3b"]["conv"].copy()
    m3bconv.append_message(m3bconv.roles[0], text)
    m3bconv.append_message(m3bconv.roles[1], None)
    # Tokenize the full templated prompt, not just the raw user message.
    input_ids = talkers["m3b"]["tokenizer"]([m3bconv.get_prompt()]).input_ids
    input_tensor = torch.as_tensor(input_ids).to(talkers["m3b"]["model"].device)
    # Use generate() for sampling; calling the model directly is a single
    # forward pass and ignores these sampling arguments.
    output_ids = talkers["m3b"]["model"].generate(
        input_tensor,
        do_sample=True,
        temperature=0.2,
        max_new_tokens=1024,
    )
    # Decode only the newly generated tokens, dropping the prompt.
    response_tokens = output_ids[0][len(input_ids[0]):]
    response = talkers["m3b"]["tokenizer"].decode(response_tokens, skip_special_tokens=True).strip()
    return response
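
# Minimal UI: an input box, an output box, and a Send button wired to
# m3b_talk. queue(concurrency_count=1) serves one request at a time
# (Gradio 3.x queue API).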
def main():
    logging.basicConfig(level=logging.INFO)
    with gr.Blocks() as demo:
        with gr.Row(variant="panel"):
            gr.Markdown("## Talk to MiniChat-3B\n\nType a message and press Send.")
        with gr.Row(variant="panel"):
            with gr.Column(variant="panel"):
                m3b_talk_input = gr.Textbox(label="Message", placeholder="Type something here...")
            with gr.Column(variant="panel"):
                m3b_talk_output = gr.Textbox(label="Response")
        m3b_talk_btn = gr.Button("Send")
        m3b_talk_btn.click(m3b_talk, inputs=m3b_talk_input, outputs=m3b_talk_output, api_name="talk_m3b")
    demo.queue(concurrency_count=1).launch()

if __name__ == "__main__":
    main()