File size: 1,653 Bytes
53f8a32
 
f4fe081
da8a172
d487976
 
 
da8a172
fe36794
2a8c299
fe36794
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
dc422ae
2a8c299
d487976
2a8c299
fe36794
ecd2bcd
2a8c299
ecd2bcd
fe36794
d487976
fe36794
 
 
 
d487976
53f8a32
 
 
 
d487976
 
 
 
 
 
 
 
53f8a32
d487976
53f8a32
4263bcd
53f8a32
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
import logging
from typing import cast
from threading import Lock
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch

from conversation import get_default_conv_template
import gradio as gr
from pyllamacpp.model import Model
import wget

# Reference only (not executed): a minimal interactive command-line chat
# loop that drives a pyllamacpp model directly, printing tokens as they
# are generated. Usage sketch:
#
#     model = Model(model_path='/path/to/model.bin')
#     while True:
#         try:
#             prompt = input("You: ")
#             if prompt == '':
#                 continue
#             print("AI:", end='')
#             for token in model.generate(prompt):
#                 print(token, end='', flush=True)
#             print()
#         except (KeyboardInterrupt, EOFError):
#             break

from huggingface_hub import hf_hub_download

# Name of the GGUF weights file (q8_0 variant) inside the Hub repository
# below -- a filename within the repo, not a local filesystem path.
model_path = "minichat-3b.q8_0.gguf"

# Fetch the weights from the Hugging Face Hub; the returned value is the
# local path of the downloaded file. Note this runs at import time.
mdlpath = hf_hub_download(repo_id="afrideva/MiniChat-3B-GGUF", filename=model_path)

# Single shared model instance used by the chat handler (m3b_talk).
# NOTE(review): older pyllamacpp releases predate the GGUF format --
# confirm the installed version can load this file.
lcpp_model = Model(model_path=mdlpath)

def m3b_talk(text, *, model=None):
    """Return the model's complete reply to *text*.

    Tokens yielded by ``model.generate`` are collected and joined into a
    single string, because the Gradio output textbox shows the whole reply
    at once.

    Args:
        text: The user's message. ``None``, empty, or whitespace-only input
            is not sent to the model and produces ``""``.
        model: Object exposing ``generate(prompt)`` that yields string
            tokens. Defaults to the module-level ``lcpp_model``; pass one
            explicitly to reuse this handler with a different model.

    Returns:
        The concatenation of all generated tokens.
    """
    # Skip blank prompts rather than asking the model to continue nothing.
    if text is None or not text.strip():
        return ""
    if model is None:
        model = lcpp_model
    # "".join is linear in the reply length; repeated `+=` can be quadratic.
    return "".join(model.generate(text))

def main():
    """Build the MiniChat-3B chat page with Gradio and serve it.

    Enables INFO-level logging and lays out a header row, then a row with
    a message box on the left and a reply box plus a "Send" button on the
    right. The button sends the message to ``m3b_talk`` and shows the
    result in the reply box; the same handler is exposed under the API
    name ``talk_m3b``. The app is launched behind a queue created with
    ``concurrency_count=1`` (Gradio 3.x API).
    """
    logging.basicConfig(level=logging.INFO)

    panel_style = "panel"
    with gr.Blocks() as app:
        # Header row.
        with gr.Row(variant=panel_style):
            gr.Markdown("## Talk to MiniChat-3B\n\nTalk to MiniChat-3B.")

        # Input column on the left; reply box and send button on the right.
        with gr.Row(variant=panel_style):
            with gr.Column(variant=panel_style):
                message_box = gr.Textbox(
                    label="Message",
                    placeholder="Type something here...",
                )
            with gr.Column(variant=panel_style):
                reply_box = gr.Textbox()
                send_button = gr.Button("Send")

        send_button.click(
            m3b_talk,
            inputs=message_box,
            outputs=reply_box,
            api_name="talk_m3b",
        )

    queued_app = app.queue(concurrency_count=1)
    queued_app.launch()


if __name__ == "__main__":
    main()