from huggingface_hub import InferenceClient
import gradio as gr

# Client for the Hugging Face Inference API, pointed at the
# instruction-tuned Gemma 7B model.
client = InferenceClient("google/gemma-7b-it")
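
# Note (assumption): google/gemma-7b-it is a gated repository, so the
# environment likely needs a Hugging Face token with access granted
# (e.g. via the HF_TOKEN environment variable, which huggingface_hub
# picks up automatically).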

def format_prompt(message, history):
    """Build a Gemma-style chat prompt from the conversation history."""
    prompt = ""
    for user_prompt, bot_response in history or []:
        prompt += f"<start_of_turn>user\n{user_prompt}<end_of_turn>\n"
        prompt += f"<start_of_turn>model\n{bot_response}<end_of_turn>\n"
    # Open a model turn so generation continues as the assistant.
    prompt += f"<start_of_turn>user\n{message}<end_of_turn>\n<start_of_turn>model\n"
    return prompt
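
# For example, with history [("Hi", "Hello!")] and message "How are you?",
# format_prompt returns:
#   <start_of_turn>user\nHi<end_of_turn>\n
#   <start_of_turn>model\nHello!<end_of_turn>\n
#   <start_of_turn>user\nHow are you?<end_of_turn>\n<start_of_turn>model\n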

def generate(
    prompt, history, temperature=0.9, max_new_tokens=256, top_p=0.95, repetition_penalty=1.0,
):
    history = history or []
    # The API rejects non-positive temperatures, so clamp to a small floor.
    temperature = max(float(temperature), 1e-2)
    top_p = float(top_p)
    generate_kwargs = dict(
        temperature=temperature,
        max_new_tokens=max_new_tokens,
        top_p=top_p,
        repetition_penalty=repetition_penalty,
        do_sample=True,
        seed=42,  # fixed seed keeps sampled outputs reproducible
    )
    formatted_prompt = format_prompt(prompt, history)
    # Stream tokens from the endpoint and yield the growing reply so the
    # Gradio chatbot renders the answer incrementally.
    stream = client.text_generation(
        formatted_prompt,
        **generate_kwargs,
        stream=True,
        details=True,
        return_full_text=False,
    )
    output = ""
    for response in stream:
        output += response.token.text
        yield output
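
# generate() is a generator: each yield is the full reply so far. Outside
# Gradio you could drain it to get the final text, e.g.:
#   reply = ""
#   for reply in generate("What is the capital of France?", []):
#       pass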

additional_inputs = [
    gr.Slider(
        label="Temperature",
        value=0.9,
        minimum=0.0,
        maximum=1.0,
        step=0.05,
        interactive=True,
        info="Higher values produce more diverse outputs",
    ),
    gr.Slider(
        label="Max new tokens",
        value=512,
        minimum=0,
        maximum=1024,
        step=64,
        interactive=True,
        info="The maximum number of new tokens",
    ),
    gr.Slider(
        label="Top-p (nucleus sampling)",
        value=0.90,
        minimum=0.0,
        maximum=1.0,
        step=0.05,
        interactive=True,
        info="Higher values sample more low-probability tokens",
    ),
    gr.Slider(
        label="Repetition penalty",
        value=1.2,
        minimum=1.0,
        maximum=2.0,
        step=0.05,
        interactive=True,
        info="Penalize repeated tokens",
    ),
]
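
# Note: gr.ChatInterface passes these slider values positionally after
# (message, history), so their order must match generate()'s signature:
# temperature, max_new_tokens, top_p, repetition_penalty.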

# Create a Chatbot object with the desired height and bubble layout; the
# placeholder is shown before the first message.
chatbot = gr.Chatbot(
    height=450,
    layout="bubble",
    placeholder="Type here to chat...",
)

with gr.Blocks() as demo:
    gr.HTML("<h1><center>🤖 Google-Gemma-7B-Chat 💬</center></h1>")
    gr.ChatInterface(
        generate,
        chatbot=chatbot,  # use the Chatbot created above
        additional_inputs=additional_inputs,
        examples=[["What is the meaning of life?"], ["Tell me something about Mt Fuji."]],
    )

demo.queue().launch(debug=True)
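
# queue() is required for the streamed (generator) responses to update the
# UI token by token; debug=True surfaces server-side errors in the console.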