MiniChat-3B / app.py
Samuel L Meyers
Initial MiniChat test
da8a172
raw
history blame
1.81 kB
import logging
from typing import cast
from threading import Lock
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch
from conversation import get_default_conv_template
import gradio as gr
# Shared, module-level inference state, keyed by model nickname.
# Loaded once at import time; m3b_talk() reads from this dict on every request.
talkers = {
    "m3b": {
        # use_fast=False requests the slow (Python) tokenizer implementation.
        "tokenizer": AutoTokenizer.from_pretrained("GeneZC/MiniChat-3B", use_fast=False),
        # device_map="auto" lets transformers place the weights automatically;
        # low_cpu_mem_usage streams the checkpoint to limit peak host RAM.
        "model": AutoModelForCausalLM.from_pretrained("GeneZC/MiniChat-3B", device_map="auto", low_cpu_mem_usage=True),
        # Conversation template object holding the running chat history and
        # the prompt format expected by MiniChat.
        "conv": get_default_conv_template("minichat")
    }
}
def m3b_talk(text):
    """Generate a MiniChat-3B reply to *text* and return it as a string.

    Appends the user message (and an empty assistant slot) to the shared
    conversation template, generates a continuation of the full prompt,
    and records the reply back into the conversation so the next turn's
    prompt includes it.

    NOTE(review): state lives in the module-level ``talkers`` dict, so
    concurrent callers share one conversation — confirm queue(concurrency_count=1)
    stays in place to serialize requests.
    """
    tokenizer = talkers["m3b"]["tokenizer"]
    model = talkers["m3b"]["model"]
    m3bconv = talkers["m3b"]["conv"]
    m3bconv.append_message(m3bconv.roles[0], text)
    m3bconv.append_message(m3bconv.roles[1], None)
    # Tokenize the FULL conversation prompt, not just the raw user text:
    # generation runs on this prompt, and the slice below must strip
    # exactly these echoed tokens.
    input_ids = tokenizer([m3bconv.get_prompt()]).input_ids
    # generate(), not a bare forward call: a forward pass does not accept
    # sampling kwargs and does not autoregressively produce new tokens.
    output_ids = model.generate(
        torch.as_tensor(input_ids),
        do_sample=True,
        temperature=0.2,
        max_new_tokens=1024,
    )
    # Drop the echoed prompt tokens, keeping only the new completion.
    output_ids = output_ids[0][len(input_ids[0]):]
    response = tokenizer.decode(output_ids, skip_special_tokens=True).strip()
    # Fill the assistant placeholder so the next turn's prompt is well-formed.
    m3bconv.update_last_message(response)
    return response
def main():
    """Build the Gradio UI for MiniChat-3B and launch it with a serial queue."""
    logging.basicConfig(level=logging.INFO)
    with gr.Blocks() as demo:
        with gr.Row(variant="panel"):
            gr.Markdown("## Talk to MiniChat-3B\n\nTalk to MiniChat-3B.")
        with gr.Row(variant="panel"):
            with gr.Column(variant="panel"):
                message_box = gr.Textbox(label="Message", placeholder="Type something here...")
            with gr.Column(variant="panel"):
                reply_box = gr.Textbox()
                send_button = gr.Button("Send")
        # Wire the button to the model handler; api_name exposes it on the API.
        send_button.click(m3b_talk, inputs=message_box, outputs=reply_box, api_name="talk_m3b")
    # One request at a time: the handler mutates shared conversation state.
    demo.queue(concurrency_count=1).launch()


if __name__ == "__main__":
    main()