File size: 3,753 Bytes
0b15f14
140793a
1997cd5
0b15f14
140793a
360f04a
 
2147ae4
0b15f14
10a0a97
140793a
 
e1eb2b8
140793a
 
 
e1eb2b8
140793a
 
 
 
360f04a
140793a
 
fec6802
 
360f04a
fec6802
 
140793a
 
 
 
2147ae4
 
 
efc5592
140793a
 
2147ae4
83a6345
140793a
 
 
 
 
 
 
c3acc2f
140793a
 
 
 
 
 
 
83a6345
140793a
c3acc2f
140793a
a68ea86
 
 
e93d658
5bdd91c
 
a68ea86
140793a
07466ed
a68ea86
b525961
 
 
 
 
 
 
 
6ae1c70
b525961
 
07466ed
706d17f
fec6802
140793a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
f441bd4
140793a
 
 
 
 
 
 
f441bd4
 
140793a
 
 
 
 
 
 
 
 
 
 
 
0b15f14
140793a
0b15f14
140793a
53aa7e6
140793a
 
 
 
 
 
d1590ee
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
import os
import gradio as gr
from huggingface_hub import InferenceClient

# Hugging Face access token; stays None when HF_TOKEN is not set in the environment.
HF_TOKEN = os.getenv("HF_TOKEN")

# Despite the *_URL names these are model identifiers, not URLs: the primary
# chat model and the secondary model used when the primary is not loaded.
API_URL = "meta-llama/Llama-2-70b-chat-hf"
API_URL_2 = "codellama/CodeLlama-34b-Instruct-hf"

# Speaker label given to the model's turns in the prompt transcript.
BOT_NAME = "Assistant"

# Suffixes that mark the end of the model's turn; they are sent with each
# generation request and also trimmed from the streamed text.
STOP_SEQUENCES = [
    "\nUser:",
    " User:",
    "###",
    "</s>",
]

# Sample prompts offered in the chat UI (one single-element row per example).
EXAMPLES = [
    ["Hey LLAMA! Any recommendations for my holidays in Abu Dhabi?"],
    ["What's the Everett interpretation of quantum mechanics?"],
    ["Give me a list of the top 10 dive sites you would recommend around the world."],
    ["Can you tell me more about deep-water soloing?"],
    ["Can you write a short tweet about the release of our latest AI model, LLAMA LLM?"],
]

# Primary inference client, bound to the model named by API_URL.
client = InferenceClient(model=API_URL, token=HF_TOKEN)

# Secondary inference client, bound to the model named by API_URL_2.
client2 = InferenceClient(model=API_URL_2, token=HF_TOKEN)

def format_prompt(message, history, system_prompt):
    """Render the conversation as a plain-text transcript for the model.

    Args:
        message: The new user message to answer.
        history: Iterable of ``(user_prompt, bot_response)`` pairs from
            earlier turns, oldest first.
        system_prompt: Optional instruction text; when truthy it is emitted
            first as a ``System:`` line.

    Returns:
        The transcript string, ending with ``"{BOT_NAME}:"`` (no trailing
        newline) so the model continues as the assistant.
    """
    pieces = []
    if system_prompt:
        pieces.append(f"System: {system_prompt}\n")
    for user_turn, bot_turn in history:
        pieces.append(f"User: {user_turn}\n")
        pieces.append(f"{BOT_NAME}: {bot_turn}\n")
    # Leave the assistant's turn open for the model to complete.
    pieces.append(f"User: {message}\n{BOT_NAME}:")
    return "".join(pieces)

# Sampling seed shared by all requests; generate() bumps it after every call so
# that repeating the same prompt does not reproduce the same sample.
seed = 42


def _strip_stop_sequences(text):
    """Return `text` with any stop sequence it currently ends with removed."""
    for stop_str in STOP_SEQUENCES:
        # Guard against an empty stop string: text[:-0] would wipe the text.
        if stop_str and text.endswith(stop_str):
            text = text[:-len(stop_str)]
    return text


def generate(
    prompt, history, system_prompt="", temperature=0.9, max_new_tokens=256, top_p=0.95, repetition_penalty=1.0,
):
    """Stream the assistant's reply to `prompt`, yielding the text so far.

    Args:
        prompt: The new user message.
        history: Earlier ``(user, assistant)`` turns, forwarded to
            ``format_prompt``.
        system_prompt: Optional instruction text placed ahead of the dialogue.
        temperature: Sampling temperature; values below 0.01 are raised to 0.01.
        max_new_tokens: Upper bound on generated tokens; coerced to an int of
            at least 1.
        top_p: Nucleus-sampling probability mass.
        repetition_penalty: Penalty applied to repeated tokens.

    Yields:
        The accumulated response text after each streamed token, with any
        trailing stop sequence removed.

    Returns:
        The final accumulated text (as the generator's return value).

    Raises:
        gr.Error: If the status check or the generation request fails.
    """
    global seed

    # UI widgets may deliver numbers as int or float; normalise before sending.
    temperature = max(float(temperature), 1e-2)
    top_p = float(top_p)
    repetition_penalty = float(repetition_penalty)
    # The UI slider starts at 0; never request fewer than one new token.
    max_new_tokens = max(int(max_new_tokens), 1)

    generate_kwargs = dict(
        temperature=temperature,
        max_new_tokens=max_new_tokens,
        top_p=top_p,
        repetition_penalty=repetition_penalty,
        stop_sequences=STOP_SEQUENCES,
        do_sample=True,
        seed=seed,
    )
    seed = seed + 1
    formatted_prompt = format_prompt(prompt, history, system_prompt)

    output = ""
    try:
        cli = client
        status = cli.get_model_status()
        print(f"Model 1 status: {status}")
        if not status.loaded and status.state == 'Loadable':
            # Poke the primary model so it starts loading, then answer this
            # request with the secondary model. The poke is best-effort: a
            # failure here must not prevent the fallback from being used.
            # NOTE(review): this call is synchronous, so the user's request
            # waits for it to return - consider issuing it in the background.
            try:
                cli.text_generation('Hello!', **generate_kwargs, return_full_text=False)
            except Exception as warmup_error:
                print(f"Model 1 warm-up request failed: {warmup_error}")
            cli = client2

        stream = cli.text_generation(
            formatted_prompt, **generate_kwargs, stream=True, details=True, return_full_text=False
        )
        for response in stream:
            output += response.token.text
            # Trim a trailing stop marker before showing the text, and emit
            # exactly one update per streamed token.
            output = _strip_stop_sequences(output)
            yield output
    except Exception as e:
        # Surface every client-side failure to the UI as a readable error.
        raise gr.Error(f"Client error while generating: {e}") from e
    return output

# Extra controls shown with the chat box. Their order mirrors the parameters of
# generate() that follow `prompt` and `history`:
# system_prompt, temperature, max_new_tokens, top_p, repetition_penalty.
additional_inputs = [
    # System prompt (free text, empty by default).
    gr.Textbox(label="Optional system prompt", value=""),
    # Sampling temperature.
    gr.Slider(
        minimum=0.0,
        maximum=1.0,
        step=0.05,
        value=0.9,
        label="Temperature",
        info="Higher values produce more diverse outputs",
        interactive=True,
    ),
    # Generation length limit.
    gr.Slider(
        minimum=0,
        maximum=3000,
        step=64,
        value=256,
        label="Max new tokens",
        info="The maximum numbers of new tokens",
        interactive=True,
    ),
    # Nucleus sampling threshold.
    gr.Slider(
        minimum=0.01,
        maximum=0.99,
        step=0.05,
        value=0.90,
        label="Top-p (nucleus sampling)",
        info="Higher values sample more low-probability tokens",
        interactive=True,
    ),
    # Repetition penalty.
    gr.Slider(
        minimum=1.0,
        maximum=2.0,
        step=0.05,
        value=1.2,
        label="Repetition penalty",
        info="Penalize repeated tokens",
        interactive=True,
    ),
]

# Assemble the UI: a single chat interface driven by generate(), with the
# example prompts and the extra sampling controls attached.
with gr.Blocks() as demo:
    gr.ChatInterface(
        fn=generate,
        additional_inputs=additional_inputs,
        examples=EXAMPLES,
    )

# Enable request queuing (API endpoints closed), then start the app with the
# API page hidden.
queued_demo = demo.queue(concurrency_count=100, api_open=False)
queued_demo.launch(show_api=False)