# File size: 1,739 Bytes
# f96f74e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
import logging
from typing import cast
from threading import Lock
from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline

import torch

from conversation import get_default_conv_template
import gradio as gr
from llama_cpp import Llama
import json

from huggingface_hub import hf_hub_download

# Path of the GGUF weights file passed to Llama below.
model_path = "starling-lm-7b-alpha.Q6_K.gguf"

# Disabled download step, kept for reference. Its repo id names MiniChat-3B,
# not the Starling file configured above.
#mdlpath = hf_hub_download(repo_id="afrideva/MiniChat-3B-GGUF", filename=model_path)

# Loaded once at import time; m3b_talk calls this instance.
lcpp_model = Llama(model_path=model_path)

# Module-level text buffer written by m3b_talk.
otxt = ""

def m3b_talk(text):
    """Stream the model's reply to ``text``, yielding the reply so far.

    Wraps ``text`` in a single-turn prompt, runs it through the module-level
    ``lcpp_model`` in streaming mode, and yields the accumulated reply after
    every chunk so a streaming consumer can overwrite its output each time.

    Each call starts from an empty reply: text is accumulated in a local
    variable rather than appended to shared module state, so replies from
    earlier (or concurrent) calls never leak into this one.

    Args:
        text: The user's message. It is concatenated into the prompt, so it
            must be a ``str``.

    Yields:
        The reply text accumulated so far (grows with every chunk).

    Returns:
        The complete reply, as the generator's return value.
    """
    global otxt
    # NOTE(review): the Starling-LM model card appears to document the roles
    # as "GPT4 Correct User:" / "GPT4 Correct Assistant:" separated by
    # "<|end_of_turn|>" - confirm this template against the model card.
    formattedQuery = "GPT4 User: " + text + "<|end_of_text|>GPT4 Assistant:"
    # echo=False: only generated text should reach the caller, so the prompt
    # never has to be stripped back out of the output.
    r = lcpp_model(formattedQuery, stop=["GPT4 User:", "\n\n\n"], echo=False, stream=True)
    reply = ""
    for c in r:
        reply += c["choices"][0]["text"]
        # Mirror the latest text into the module-level buffer for any code
        # that still reads it; the yielded value never depends on it.
        otxt = reply
        yield reply
    return reply

def main():
    """Build the single-message chat page and launch it with queuing on.

    The layout is a heading row followed by a two-column row: the message
    textbox on one side, the reply textbox and "Send" button on the other.
    Clicking the button calls ``m3b_talk`` with the message and routes its
    output to the reply textbox; the handler is registered with
    ``api_name="talk_m3b"``.
    """
    logging.basicConfig(level=logging.INFO)

    with gr.Blocks() as app:
        with gr.Row(variant="panel"):
            gr.Markdown("## Talk to MiniChat-3B\n\nTalk to MiniChat-3B.")
        with gr.Row(variant="panel"):
            with gr.Column(variant="panel"):
                message_box = gr.Textbox(
                    label="Message",
                    placeholder="Type something here...",
                )
            with gr.Column(variant="panel"):
                reply_box = gr.Textbox()
                send_button = gr.Button("Send")

        # Event wiring has to happen while the Blocks context is still open.
        send_button.click(
            m3b_talk,
            inputs=message_box,
            outputs=reply_box,
            api_name="talk_m3b",
        )

    queued_app = app.queue()
    queued_app.launch()


# Start the UI only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()