File size: 4,366 Bytes
140793a
0b15f14
140793a
 
0b15f14
140793a
 
0b15f14
140793a
e1eb2b8
fec6802
e1eb2b8
0b15f14
10a0a97
140793a
 
e1eb2b8
140793a
 
 
e1eb2b8
140793a
 
 
 
 
 
 
fec6802
 
 
 
 
140793a
 
 
 
 
 
e1eb2b8
140793a
 
 
 
83a6345
 
140793a
 
 
 
 
 
 
83a6345
140793a
 
 
 
 
 
 
83a6345
140793a
83a6345
140793a
 
07466ed
 
b525961
 
 
 
 
 
 
 
6ae1c70
b525961
 
07466ed
fec6802
 
 
 
140793a
fec6802
 
 
 
 
 
 
 
 
 
 
 
140793a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
f441bd4
140793a
 
 
 
 
 
 
f441bd4
 
140793a
 
 
 
 
 
 
 
 
 
 
 
0b15f14
140793a
0b15f14
140793a
53aa7e6
140793a
 
 
 
 
 
d1590ee
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
import json
import os
import shutil
import requests

import gradio as gr
from huggingface_hub import Repository, InferenceClient

# Access token for the Hugging Face Inference API, read from the environment.
# It is None when the HF_TOKEN variable is not set.
HF_TOKEN = os.environ.get("HF_TOKEN", None)

# Inference endpoints: the primary chat model and a second model.
API_URL = "https://api-inference.huggingface.co/models/meta-llama/Llama-2-70b-chat-hf"
API_URL_2 = "https://api-inference.huggingface.co/models/codellama/CodeLlama-34b-Instruct-hf"
BOT_NAME = "LLAMA"

# Generation stops (and the UI trims the reply) when one of these is produced.
STOP_SEQUENCES = ["\nUser:", " User:", "###", "</s>"]

# Example prompts offered by the chat interface.
EXAMPLES = [
    ["Hey LLAMA! Any recommendations for my holidays in Abu Dhabi?"],
    ["What's the Everett interpretation of quantum mechanics?"],
    ["Give me a list of the top 10 dive sites you would recommend around the world."],
    ["Can you tell me more about deep-water soloing?"],
    ["Can you write a short tweet about the release of our latest AI model, LLAMA LLM?"],
]


def _auth_headers():
    """Return the Authorization header for the Inference API.

    Returns None when no token is configured, so that the literal string
    "Bearer None" is never sent as a credential.
    """
    if not HF_TOKEN:
        return None
    return {"Authorization": f"Bearer {HF_TOKEN}"}


# Each client gets its own headers dict so the two never share mutable state.
client = InferenceClient(
    API_URL,
    headers=_auth_headers(),
)

client2 = InferenceClient(
    API_URL_2,
    headers=_auth_headers(),
)

def format_prompt(message, history, system_prompt):
    """Build the plain-text chat transcript that is sent to the model.

    The transcript has one line per turn: an optional ``System:`` line,
    then alternating ``User:`` / ``LLAMA:`` lines for every past exchange,
    then the new user message, and finally a bare ``LLAMA:`` cue for the
    model to continue from.

    The final cue uses the same ``LLAMA`` speaker label as the history
    lines. It previously read ``Falcon:``, a label that appears nowhere
    else in the transcript.

    Args:
        message: The new user message.
        history: Iterable of ``(user_prompt, bot_response)`` pairs from
            earlier turns; may be empty.
        system_prompt: Optional system instruction; skipped when falsy.

    Returns:
        The prompt string, ending with ``"LLAMA:"`` and no trailing newline.
    """
    lines = []
    if system_prompt:
        lines.append(f"System: {system_prompt}")
    for user_prompt, bot_response in history:
        lines.append(f"User: {user_prompt}")
        lines.append(f"LLAMA: {bot_response}")
    lines.append(f"User: {message}")
    lines.append("LLAMA:")
    return "\n".join(lines)

# Sampling seed; generate() passes the current value and then increments it.
seed = 42


def _stream_completion(inference_client, formatted_prompt, generate_kwargs):
    """Yield the accumulated reply text after each token streamed by a client.

    A stop sequence found at the end of the accumulated text is trimmed
    before the text is yielded, so it is not shown in the chat.
    """
    stream = inference_client.text_generation(
        formatted_prompt,
        **generate_kwargs,
        stream=True,
        details=True,
        return_full_text=False,
    )
    output = ""
    for response in stream:
        output += response.token.text
        for stop_str in STOP_SEQUENCES:
            if output.endswith(stop_str):
                output = output[: -len(stop_str)]
        # One yield per token; the text is cumulative, not a delta.
        yield output


def generate(
    prompt, history, system_prompt="", temperature=0.9, max_new_tokens=256, top_p=0.95, repetition_penalty=1.0,
):
    """Stream a chat reply, falling back to the second client on failure.

    The request is first sent through ``client``. If that raises (before
    or during streaming), the same request is retried from scratch through
    ``client2``. Previously the handler raised immediately, which left the
    fallback code unreachable; that dead code also called ``client`` again
    instead of ``client2``.

    Args:
        prompt: The new user message.
        history: Earlier ``(user, bot)`` exchanges, passed to ``format_prompt``.
        system_prompt: Optional system instruction.
        temperature: Sampling temperature; values below 0.01 are raised to 0.01.
        max_new_tokens: Upper bound on generated tokens.
        top_p: Nucleus-sampling threshold.
        repetition_penalty: Penalty applied to repeated tokens.

    Yields:
        The reply text accumulated so far, once per streamed token.

    Raises:
        gr.Error: If both clients fail; the message carries both errors.
    """
    global seed

    # Keep the temperature strictly positive (the UI slider allows 0.0).
    temperature = max(float(temperature), 1e-2)
    top_p = float(top_p)

    generate_kwargs = dict(
        temperature=temperature,
        max_new_tokens=max_new_tokens,
        top_p=top_p,
        repetition_penalty=repetition_penalty,
        stop_sequences=STOP_SEQUENCES,
        do_sample=True,
        seed=seed,
    )
    # Advance the seed so the next call samples with a different value.
    seed = seed + 1
    formatted_prompt = format_prompt(prompt, history, system_prompt)

    try:
        yield from _stream_completion(client, formatted_prompt, generate_kwargs)
        return
    except Exception as primary_error:
        # Keep the error: the name bound by "except ... as" is cleared
        # when the handler exits.
        first_error = primary_error

    try:
        # The fallback restarts the reply; each yield carries the full text,
        # so a partial primary reply is simply superseded.
        yield from _stream_completion(client2, formatted_prompt, generate_kwargs)
    except Exception as fallback_error:
        raise gr.Error(
            f"Client 1 error while generating: {first_error}; "
            f"Client 2 error while generating: {fallback_error}"
        ) from fallback_error

def _slider(label, info, *, value, minimum, maximum, step):
    """Create an interactive slider with the given label, help text and range."""
    return gr.Slider(
        label=label,
        value=value,
        minimum=minimum,
        maximum=maximum,
        step=step,
        interactive=True,
        info=info,
    )


# Extra controls passed to the chat interface. Their order matches the
# parameters of generate() that follow (prompt, history): system prompt,
# temperature, max new tokens, top-p, repetition penalty.
additional_inputs = [
    gr.Textbox(value="", label="Optional system prompt"),
    _slider(
        "Temperature",
        "Higher values produce more diverse outputs",
        value=0.9,
        minimum=0.0,
        maximum=1.0,
        step=0.05,
    ),
    _slider(
        "Max new tokens",
        "The maximum numbers of new tokens",
        value=256,
        minimum=0,
        maximum=3000,
        step=64,
    ),
    _slider(
        "Top-p (nucleus sampling)",
        "Higher values sample more low-probability tokens",
        value=0.90,
        minimum=0.01,
        maximum=0.99,
        step=0.05,
    ),
    _slider(
        "Repetition penalty",
        "Penalize repeated tokens",
        value=1.2,
        minimum=1.0,
        maximum=2.0,
        step=0.05,
    ),
]

# Build the UI: a single chat interface driven by generate(), with the
# example prompts and the extra generation controls attached.
with gr.Blocks() as demo:
    gr.ChatInterface(
        fn=generate,
        additional_inputs=additional_inputs,
        examples=EXAMPLES,
    )

# Enable the request queue, then start the app with the API page hidden.
demo.queue(concurrency_count=100, api_open=False)
demo.launch(show_api=False)