import logging
from typing import cast
from threading import Lock

from transformers import AutoModelForCausalLM, AutoTokenizer
import torch

from conversation import get_default_conv_template
import gradio as gr

# Registry of chat backends keyed by short name. Note that the tokenizer and
# model are loaded eagerly when this module is imported.
talkers = {
    "m3b": {
        "tokenizer": AutoTokenizer.from_pretrained("GeneZC/MiniChat-3B", use_fast=False),
        "model": AutoModelForCausalLM.from_pretrained(
            "GeneZC/MiniChat-3B", device_map="auto", low_cpu_mem_usage=True
        ),
        # Single conversation object shared by every request (one global chat history).
        "conv": get_default_conv_template("minichat"),
    }
}

# Serializes access to the shared conversation history in talkers["m3b"]["conv"],
# which each call to m3b_talk mutates.
_m3b_lock = Lock()


def m3b_talk(text):
    """Send *text* to MiniChat-3B and return the model's reply.

    The message is appended to the shared conversation history, the full
    templated prompt is generated from, and the reply is recorded back into
    the history so later turns see it.

    Args:
        text: The user's message. Empty or whitespace-only input is ignored.

    Returns:
        The decoded, stripped model reply ("" for empty input).

    Raises:
        Whatever the tokenizer/model raise during generation; in that case the
        user turn added by this call is removed from the history again.
    """
    if not text or not text.strip():
        # Nothing to say; don't pollute the shared history with an empty turn.
        return ""

    tokenizer = talkers["m3b"]["tokenizer"]
    model = talkers["m3b"]["model"]
    conv = talkers["m3b"]["conv"]

    with _m3b_lock:
        conv.append_message(conv.roles[0], text)
        # Empty assistant slot so get_prompt() ends with the assistant's turn marker.
        conv.append_message(conv.roles[1], None)
        try:
            prompt = conv.get_prompt()
            # Tokenize the full templated prompt (not just the raw text) so the
            # slice below strips exactly the prompt tokens from the output.
            input_ids = tokenizer([prompt]).input_ids
            output_ids = model.generate(
                torch.as_tensor(input_ids).to(model.device),
                do_sample=True,
                temperature=0.2,
                max_new_tokens=1024,
            )
            # generate() returns prompt + continuation; keep only the new tokens.
            output_ids = output_ids[0][len(input_ids[0]):]
            response = tokenizer.decode(output_ids, skip_special_tokens=True).strip()
        except Exception:
            # Roll back the two messages appended above so a failed call does
            # not leave a dangling user turn / None reply in the history.
            del conv.messages[-2:]
            raise
        # Fill the placeholder assistant slot with the actual reply.
        # NOTE(review): assumes FastChat-style messages stored as mutable
        # [role, message] lists — confirm against conversation.py.
        conv.messages[-1][-1] = response

    return response


def main():
    """Build and launch the Gradio UI for chatting with MiniChat-3B."""
    logging.basicConfig(level=logging.INFO)
    with gr.Blocks() as demo:
        with gr.Row(variant="panel"):
            gr.Markdown("## Talk to MiniChat-3B\n\nTalk to MiniChat-3B.")
        with gr.Row(variant="panel"):
            with gr.Column(variant="panel"):
                m3b_talk_input = gr.Textbox(label="Message", placeholder="Type something here...")
            with gr.Column(variant="panel"):
                m3b_talk_output = gr.Textbox()
        m3b_talk_btn = gr.Button("Send")
        m3b_talk_btn.click(m3b_talk, inputs=m3b_talk_input, outputs=m3b_talk_output, api_name="talk_m3b")
    # One request at a time: all requests share a single conversation history.
    demo.queue(concurrency_count=1).launch()


if __name__ == "__main__":
    main()