# kminterface bundles the pieces used below: the text-generation client
# (kminterface.client), Gradio (kminterface.gr), and the shared chat widget
# and extra inputs (kminterface.mychatbot, kminterface.additional_inputs).
import kminterface

def tokenize(text):
    # Placeholder tokenizer: the prompt passes through as plain text. Swap in
    # a real tokenizer here if token IDs are needed, e.g.:
    # return tok.encode(text, add_special_tokens=False)
    return text

def format_prompt(message, history):
    # Build a Mistral/Mixtral-style instruction prompt: each past turn becomes
    # <s>[INST]user[/INST]response</s>, followed by the new message in [INST] tags.
    prompt = ""
    for user_prompt, bot_response in history:
        prompt += "<s>" + tokenize("[INST]") + tokenize(user_prompt) + tokenize("[/INST]")
        prompt += tokenize(bot_response) + "</s> "
    prompt += tokenize("[INST]") + tokenize(message) + tokenize("[/INST]")
    return prompt

def generate(prompt, history, system_prompt, temperature=0.2, max_new_tokens=512, top_p=0.95, repetition_penalty=1.0):
    print(type(history), history)  # debug: inspect the chat history Gradio passes in
    temperature = float(temperature)
    if temperature < 1e-2:
        temperature = 1e-2  # clamp to a small positive value; sampling breaks at 0
    top_p = float(top_p)

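    # Sampling parameters forwarded to the text-generation endpoint.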
    generate_kwargs = dict(
        temperature=temperature,
        max_new_tokens=max_new_tokens,
        top_p=top_p,
        repetition_penalty=repetition_penalty,
        do_sample=True,
        seed=42,
    )

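    # Prepend the system prompt to the user message before formatting the turn history.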
    formatted_prompt = format_prompt(f"{system_prompt}, {prompt}", history)

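    # Stream the completion token by token from the inference endpoint.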
    stream = kminterface.client.text_generation(formatted_prompt, **generate_kwargs, stream=True, details=True, return_full_text=False)
    output = ""

    # Yield the growing partial output so the chat UI updates as tokens arrive.
    for response in stream:
        print(response.token.text)  # debug: log each streamed token
        output += response.token.text
        yield output
    return output

    

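# Wire the generator into a Gradio chat UI, reusing the shared widgets from kminterface.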
demo = kminterface.gr.ChatInterface(
    fn=generate,
    chatbot=kminterface.mychatbot,
    additional_inputs=kminterface.additional_inputs,
    title="Kamran's Mixtral 8x7b Chat",
    retry_btn=None,
    undo_btn=None,
)

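# Enable request queuing and launch the app; show_api=False hides the auto-generated API page.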
demo.queue().launch(show_api=False)