File size: 2,611 Bytes
04c25c5
 
9b9128d
1d73b44
 
 
 
 
 
 
 
 
 
 
 
e9b47ff
d786de6
e9b47ff
8686afb
e9b47ff
8686afb
d786de6
8686afb
e9b47ff
1d73b44
 
 
 
369dc1f
2a37133
04c25c5
 
 
 
 
 
369dc1f
04c25c5
 
 
369dc1f
04c25c5
 
 
 
 
49e3627
 
04c25c5
 
 
 
 
 
 
 
 
d7942b7
 
04c25c5
cffbd3f
04c25c5
 
 
 
 
 
 
 
 
 
 
 
c96912f
d7942b7
2a37133
 
d7942b7
04c25c5
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
import gradio as gr
from huggingface_hub import InferenceClient
import random
# Hugging Face Hub model IDs offered to the user, in display order.
models = [
    "google/gemma-7b",
    "google/gemma-7b-it",
    "google/gemma-2b",
    "google/gemma-2b-it",
]
# One inference client per model, derived from ``models`` so that
# ``clients[i]`` always corresponds to ``models[i]`` (the two lists can no
# longer drift apart when a model is added or removed).
clients = [InferenceClient(model) for model in models]
def format_prompt(message, history):
    """Build a Gemma chat-template prompt from prior turns plus a new message.

    Every turn is wrapped as ``<start_of_turn>ROLE\\n...<end_of_turn>\\n``.
    The result ends with an open ``<start_of_turn>model\\n`` turn so the
    model continues with its reply.

    Args:
        message: The new user message.
        history: Iterable of ``(user_prompt, bot_response)`` pairs, or a falsy
            value when there is no prior conversation.

    Returns:
        The formatted prompt string.
    """
    turns = []
    for user_prompt, bot_response in history or []:
        # A missing reply is rendered as empty text rather than "None".
        reply = "" if bot_response is None else bot_response
        turns.append(f"<start_of_turn>user\n{user_prompt}<end_of_turn>\n")
        # The model turn must be closed too; otherwise the next user turn
        # runs straight into the previous reply.
        turns.append(f"<start_of_turn>model\n{reply}<end_of_turn>\n")
    turns.append(
        f"<start_of_turn>user\n{message}<end_of_turn>\n<start_of_turn>model\n"
    )
    return "".join(turns)



def chat_inf(system_prompt,prompt,history,client_choice):
    """Stream a model reply to ``prompt`` into the chatbot.

    Args:
        system_prompt: Optional text prepended to the user message (the
            prompt template has no separate system role).
        prompt: The user's new message.
        history: Current chatbot value, a list of ``(user, bot)`` pairs, or a
            falsy value when the conversation is empty.
        client_choice: Zero-based index into ``clients`` (the model dropdown
            uses ``type='index'``); ``None`` falls back to the first model.

    Yields:
        The whole conversation as a list of ``(user, bot)`` pairs. The
        in-progress reply is the last pair, re-yielded after every streamed
        token, and then the final history once more at the end.
    """
    # The dropdown already delivers a 0-based index; the former ``- 1``
    # selected the wrong model (choosing the first entry used the last client).
    index = 0 if client_choice is None else int(client_choice)
    client = clients[index]
    # Work on a copy so the caller's list is never mutated in place.
    history = list(history) if history else []

    generate_kwargs = dict(
        temperature=0.9,
        max_new_tokens=6000,  # original note: token max = 8192
        top_p=0.95,
        repetition_penalty=1.0,
        do_sample=True,
        seed=random.randint(1, 1111111111111111),
    )
    # Prepend the system prompt only when one was given; otherwise the
    # message used to begin with a stray ", ".
    message = f"{system_prompt}, {prompt}" if system_prompt else prompt
    formatted_prompt = format_prompt(message, history)
    stream = client.text_generation(
        formatted_prompt,
        **generate_kwargs,
        stream=True,
        details=True,
        return_full_text=False,
    )

    output = ""
    for response in stream:
        # Special tokens (e.g. end-of-sequence markers) must not leak into
        # the displayed reply.
        if response.token.special:
            continue
        output += response.token.text
        # Keep earlier turns visible while streaming, not just this exchange.
        yield history + [(prompt, output)]
    history.append((prompt, output))
    yield history

def clear_fn():
    """Return one ``None`` per output component, clearing all three."""
    return (None,) * 3
with gr.Blocks() as app:
    # Page header. All tags are closed; the old ``<h7>`` is not an HTML
    # element and ``<center>`` was never closed.
    gr.Markdown("""<center><h1>Google Gemma Models</h1><br><h3>running with Huggingface Inference Client</h3><br><h6>EXPERIMENTAL</h6></center>""")
    with gr.Group():
        chat_b = gr.Chatbot()
        with gr.Row():
            with gr.Column(scale=3):
                inp = gr.Textbox(label="Prompt")
                sys_inp = gr.Textbox(label="System Prompt (optional)")
                btn = gr.Button("Chat")

            with gr.Column(scale=1):
                with gr.Group():
                    stop_btn = gr.Button("Stop")
                    clear_btn = gr.Button("Clear")
        # type='index' makes handlers receive the selected model's position
        # in ``models`` rather than its name.
        client_choice = gr.Dropdown(
            label="Models",
            type='index',
            choices=list(models),
            value=models[0],
            interactive=True,
        )

    # Keep a handle on the streaming event so the Stop button can cancel it.
    go = btn.click(chat_inf, [sys_inp, inp, chat_b, client_choice], chat_b)
    stop_btn.click(None, None, None, cancels=go)
    clear_btn.click(clear_fn, None, [inp, sys_inp, chat_b])
app.launch()