seyf1elislam's picture
Update app.py
48766de verified
raw
history blame
3.47 kB
import os
import subprocess
import sys
import time
from threading import Thread

import gradio as gr
import spaces
import torch
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    TextIteratorStreamer,
)
# Best-effort install of flash-attn at start-up. The result is not checked, so
# a failed install does not stop the app from loading the model below.
# FLASH_ATTENTION_SKIP_CUDA_BUILD asks flash-attn's build to skip compiling its
# CUDA extension (NOTE(review): assumes a prebuilt wheel is usable here - confirm).
subprocess.run(
    # Use the running interpreter's pip so the package lands in the same
    # environment this script imports from; an argument list avoids a shell.
    [sys.executable, "-m", "pip", "install", "flash-attn", "--no-build-isolation"],
    # Extend, rather than replace, the current environment: passing only the
    # flag would strip PATH, HOME, CUDA variables, etc. from the child process.
    env={**os.environ, "FLASH_ATTENTION_SKIP_CUDA_BUILD": "TRUE"},
    check=False,
)
# Hub checkpoint used for both the weights and the tokenizer, kept in one
# place so the two can never drift apart.
_MODEL_ID = "microsoft/Phi-3-small-8k-instruct"

# trust_remote_code=True lets transformers execute the custom model/tokenizer
# code that is hosted alongside this checkpoint on the Hub.
model = AutoModelForCausalLM.from_pretrained(
    _MODEL_ID,
    torch_dtype="auto",
    trust_remote_code=True,
)
tok = AutoTokenizer.from_pretrained(_MODEL_ID, trust_remote_code=True)

# Token ids that stop generation: currently just the tokenizer's EOS token.
terminators = [tok.eos_token_id]

# Run on CUDA when it is available, otherwise on the CPU.
_cuda_ok = torch.cuda.is_available()
device = torch.device("cuda" if _cuda_ok else "cpu")
if _cuda_ok:
    print(f"Using GPU: {torch.cuda.get_device_name(device)}")
else:
    print("Using CPU")
model = model.to(device)
def _build_conversation(message, history, system_prompt):
    """Turn Gradio's pair-style history into a chat-template message list.

    Args:
        message: The new user message.
        history: Sequence of ``[user_text, assistant_text]`` pairs; the
            assistant entry may be ``None`` for a turn without a reply yet.
        system_prompt: Instruction placed first with the ``system`` role;
            skipped when empty.

    Returns:
        A list of ``{"role": ..., "content": ...}`` dicts ending with the new
        user message.
    """
    conversation = []
    if system_prompt:
        # The system prompt belongs to the "system" role; sending it as
        # "assistant" made the model believe it had already said this text.
        conversation.append({"role": "system", "content": system_prompt})
    for item in history:
        user_turn, assistant_turn = item[0], item[1]
        conversation.append({"role": "user", "content": user_turn})
        if assistant_turn is not None:
            conversation.append({"role": "assistant", "content": assistant_turn})
    conversation.append({"role": "user", "content": message})
    return conversation


# NOTE(review): presumably requests a ZeroGPU allocation of up to 60 s per call.
@spaces.GPU(duration=60)
def chat(message, history, system_prompt, temperature, do_sample, max_tokens, top_k, repetition_penalty, top_p):
    """Stream the model's reply to ``message``, given the prior ``history``.

    Used as the ``gr.ChatInterface`` callback; the trailing parameters come
    from the "Parameters" accordion widgets, in the same order.

    Args:
        message: The new user message.
        history: Prior ``[user_text, assistant_text]`` pairs.
        system_prompt: System instruction (omitted when empty).
        temperature: Sampling temperature; ``0`` forces greedy decoding.
        do_sample: Whether to sample (the "Sampling" checkbox).
        max_tokens: Maximum number of new tokens to generate.
        top_k: Top-k cutoff used when sampling.
        repetition_penalty: Repetition penalty (must be > 0).
        top_p: Nucleus-sampling cutoff used when sampling.

    Yields:
        The accumulated reply text after each streamed chunk, then the final
        text once more when the stream is exhausted.
    """
    conversation = _build_conversation(message, history, system_prompt)
    prompt = tok.apply_chat_template(conversation, tokenize=False, add_generation_prompt=True)
    model_inputs = tok([prompt], return_tensors="pt").to(device)
    streamer = TextIteratorStreamer(
        tok, timeout=20.0, skip_prompt=True, skip_special_tokens=True
    )

    # Honour the "Sampling" checkbox; a temperature of 0 always means greedy.
    sample = bool(do_sample) and temperature != 0
    generate_kwargs = dict(
        model_inputs,
        streamer=streamer,
        # Slider values may arrive as floats; transformers expects an int
        # token budget and a float penalty.
        max_new_tokens=int(max_tokens),
        do_sample=sample,
        eos_token_id=terminators,
        repetition_penalty=float(repetition_penalty),
    )
    if sample:
        # These knobs only affect sampling; passing them to greedy decoding
        # just triggers transformers warnings. top_k must be an int.
        generate_kwargs.update(
            temperature=float(temperature),
            top_k=int(top_k),
            top_p=float(top_p),
        )

    # generate() blocks, so it runs on a worker thread while this generator
    # drains the streamer and yields partial text to the UI.
    worker = Thread(target=model.generate, kwargs=generate_kwargs)
    worker.start()
    partial_text = ""
    for new_text in streamer:
        partial_text += new_text
        yield partial_text
    yield partial_text
# Chat UI wired to chat(); the accordion widgets map positionally onto chat()'s
# parameters after (message, history).
demo = gr.ChatInterface(
    fn=chat,
    examples=[
        ["Write me a poem about Machine Learning."],
        ["write fibonacci sequence in python"],
        ["who won the world cup in 2018?"],
        ["when was the first computer invented?"],
    ],
    additional_inputs_accordion=gr.Accordion(
        label="⚙️ Parameters", open=False, render=False
    ),
    additional_inputs=[
        # -> system_prompt
        gr.Textbox("Perform the task to the best of your ability.", label="System prompt"),
        # -> temperature (0 switches chat() to greedy decoding)
        gr.Slider(
            minimum=0, maximum=1, step=0.1, value=0.9, label="Temperature", render=False
        ),
        # -> do_sample
        gr.Checkbox(label="Sampling", value=True),
        # -> max_tokens
        gr.Slider(
            minimum=128,
            maximum=4096,
            step=1,
            value=512,
            label="Max new tokens",
            render=False,
        ),
        # -> top_k. An explicit step=1 keeps it an integer; without it Gradio
        # derives a fractional step from the range, which transformers rejects.
        gr.Slider(1, 80, 40, step=1, label="Top K sampling"),
        # -> repetition_penalty. transformers requires a strictly positive
        # value, so 0 is excluded from the range.
        gr.Slider(0.1, 2, 1.1, label="Repetition penalty"),
        # -> top_p
        gr.Slider(0, 1, 0.95, label="Top P sampling"),
    ],
    stop_btn="Stop Generation",
    title="Chat With Phi-3-Small-8k-7b-Instruct",
    description="[microsoft/Phi-3-small-8k-instruct](https://huggingface.co/microsoft/Phi-3-small-8k-instruct)",
    css="footer {visibility: hidden}",
    # Community theme from the Hugging Face Hub (append "@x.y.z" to pin a
    # version). The previous value was an e-mail-obfuscation artifact that
    # Gradio could not resolve, so it silently fell back to the default theme.
    theme="NoCrypt/miku",
)
demo.launch()