# mychat1 / app.py
import gradio as gr
from gradio_client import Client
from huggingface_hub import InferenceClient
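# Client for the external Space that renders a chat transcript to an image
# (used by get_screenshot below).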
ss_client = Client("https://omnibus-html-image-current-tab.hf.space/")
models = [
    "google/gemma-7b",
    "google/gemma-7b-it",
    "google/gemma-2b",
    "google/gemma-2b-it",
]
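# One InferenceClient per model; chat_inf below only uses clients[0] (gemma-7b).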
clients = [InferenceClient(model) for model in models]
VERBOSE = False  # set True to print prompts and length stats to the console
def load_models():
    # Relabel the chatbot with the active model name (wired up in app.load below).
    return gr.update(label=models[0])
def format_prompt(message, history):
    # Render prior (user, bot) turns with the Gemma chat template, then open a
    # new user turn for the current message and cue the model's reply.
    prompt = ""
    if history:
        for user_prompt, bot_response in history:
            prompt += f"<start_of_turn>user\n{user_prompt}<end_of_turn>\n"
            prompt += f"<start_of_turn>model\n{bot_response}<end_of_turn>\n"
    prompt += f"<start_of_turn>user\n{message}<end_of_turn>\n<start_of_turn>model\n"
    if VERBOSE:
        print(prompt)
    return prompt
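# For one prior exchange, format_prompt produces (illustrative):
#   <start_of_turn>user
#   Hi<end_of_turn>
#   <start_of_turn>model
#   Hello!<end_of_turn>
#   <start_of_turn>user
#   How are you?<end_of_turn>
#   <start_of_turn>model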
def chat_inf(prompt, history, memory, temp, tokens, top_p, rep_p, chat_mem):
    # Streaming generator: yields (chatbot history, memory) pairs so Gradio
    # can refresh the UI as tokens arrive.
    client = clients[0]
    if not history:
        history = []
    if not memory:
        memory = []
    chat_mem = int(chat_mem)  # gr.Number delivers a float; slicing needs an int
    # Rough, character-based estimate of the input length: the new prompt plus
    # the retained memory turns.
    hist_len = 0
    for ea in memory[-chat_mem:]:
        hist_len += len(str(ea))
    in_len = len(prompt) + hist_len
    if (in_len + tokens) > 8000:
        history.append(
            (
                prompt,
                "That request is too long; reduce the 'Chat Memory' value or the 'Max new tokens' value.",
            )
        )
        yield history, memory
    else:
        generate_kwargs = dict(
            temperature=temp,
            max_new_tokens=int(tokens),
            top_p=top_p,
            repetition_penalty=rep_p,
            do_sample=True,
        )
        formatted_prompt = format_prompt(prompt, memory[-chat_mem:])
        stream = client.text_generation(
            formatted_prompt,
            **generate_kwargs,
            stream=True,
            details=True,
            return_full_text=False,  # stream only the new tokens, not the prompt
        )
output = ""
for response in stream:
output += response.token.text
yield [(prompt, output)], memory
history.append((prompt, output))
memory.append((prompt, output))
yield history, memory
if VERBOSE:
print("\n######### HIST " + str(in_len))
print("\n######### TOKENS " + str(tokens))
def get_screenshot(
    chat: list,
    height=5000,
    width=600,
    chatblock=[],
    theme="light",
    wait=3000,
    header=True,
):
    # Render the chat transcript to an image via the external Space. When
    # specific chat blocks are selected, the image lands at a different index
    # in the Space's result tuple.
    tog = 3 if chatblock else 0
    result = ss_client.predict(
        str(chat),
        height,
        width,
        chatblock,
        header,
        theme,
        wait,
        api_name="/run_script",
    )
    return f"https://omnibus-html-image-current-tab.hf.space/file={result[tog]}"
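# Hypothetical usage (the wiring for this helper sits outside this excerpt):
#   img_url = get_screenshot([("Hi", "Hello!")], chatblock=[1], theme="dark")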
def clear_fn():
    # Reset four components to empty; the event wiring for this lives outside
    # this excerpt.
    return None, None, None, None
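# UI: a chatbot above the prompt box and generation controls; conversation
# memory is kept in a gr.State component.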
with gr.Blocks() as app:
    memory = gr.State()
    chat_b = gr.Chatbot(height=500)
    with gr.Group():
        with gr.Row():
            with gr.Column(scale=3):
                inp = gr.Textbox(label="Prompt")
                btn = gr.Button("Chat")
            with gr.Column(scale=1):
                with gr.Group():
                    temp = gr.Slider(
                        label="Temperature",
                        step=0.01,
                        minimum=0.01,
                        maximum=1.0,
                        value=0.49,
                    )
                    tokens = gr.Slider(
                        label="Max new tokens",
                        value=1600,
                        minimum=0,
                        maximum=8000,
                        step=64,
                        interactive=True,
                        visible=True,
                        info="The maximum number of new tokens to generate",
                    )
                    top_p = gr.Slider(
                        label="Top-P",
                        step=0.01,
                        minimum=0.01,
                        maximum=1.0,
                        value=0.49,
                    )
                    rep_p = gr.Slider(
                        label="Repetition Penalty",
                        step=0.01,
                        minimum=0.1,
                        maximum=2.0,
                        value=0.99,
                    )
                    chat_mem = gr.Number(
                        label="Chat Memory",
                        info="Number of previous exchanges to retain",
                        value=4,
                        precision=0,  # whole exchanges; the value is used as a list slice
                    )
    # Label the chatbot with the active model on startup.
    app.load(load_models, None, [chat_b])
    # Submitting the textbox and clicking the button both stream chat_inf's
    # output into the chatbot and the memory state.
    chat_sub = inp.submit(
        chat_inf,
        [inp, chat_b, memory, temp, tokens, top_p, rep_p, chat_mem],
        [chat_b, memory],
    )
    go = btn.click(
        chat_inf,
        [inp, chat_b, memory, temp, tokens, top_p, rep_p, chat_mem],
        [chat_b, memory],
    )