import gradio as gr
import os
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
from threading import Thread
from unsloth.chat_templates import get_chat_template
from unsloth import FastLanguageModel
import torch
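
# Assumed environment: unsloth, transformers, gradio and torch installed; Unsloth requires a CUDA GPU.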

PLACEHOLDER = """
<div style="padding: 30px; text-align: center; display: flex; flex-direction: column; align-items: center;">
</div>
"""

css = """
h1 {
    text-align: center;
    display: block;
}

#duplicate-button {
    margin: auto;
    color: white;
    background: #1565c0;
    border-radius: 100vh;
}
"""

max_seq_length = 2048  # Choose any; Unsloth supports RoPE scaling internally.
dtype = None  # None for auto detection. Use float16 for Tesla T4/V100, bfloat16 for Ampere+.
load_in_4bit = True  # Use 4-bit quantization to reduce memory usage. Can be False.

model, tokenizer = FastLanguageModel.from_pretrained(
    model_name="umair894/llama3",
    max_seq_length=max_seq_length,
    dtype=dtype,
    load_in_4bit=load_in_4bit,
)
FastLanguageModel.for_inference(model)  # Switch Unsloth into its faster inference mode

# Apply the chat template to the tokenizer
tokenizer = get_chat_template(
    tokenizer,
    chat_template="llama-3",  # Supports zephyr, chatml, mistral, llama, alpaca, vicuna, vicuna_old, unsloth
    mapping={"role": "from", "content": "value", "user": "human", "assistant": "gpt"},  # ShareGPT style
    map_eos_token=True,  # Maps the template's end token to </s> instead
)
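
# With this mapping the tokenizer expects ShareGPT-style messages, for example:
# [{"from": "human", "value": "Hi"}, {"from": "gpt", "value": "Hello there!"}]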

# Llama-3 Instruct ends each turn with <|eot_id|>, so generation should stop on it as well as on the EOS token
terminators = [
    tokenizer.eos_token_id,
    tokenizer.convert_tokens_to_ids("<|eot_id|>"),
]
# Drop any terminator that could not be resolved and fall back to the EOS token alone
terminators = [token_id for token_id in terminators if token_id is not None]
if not terminators:
    terminators = [tokenizer.eos_token_id]


def chat_llama3_8b(message: str,
                   history: list,
                   temperature: float,
                   max_new_tokens: int
                   ) -> str:
    """
    Generate a streaming response using the llama3-8b model.

    Args:
        message (str): The input message.
        history (list): The conversation history used by ChatInterface.
        temperature (float): The sampling temperature for generating the response.
        max_new_tokens (int): The maximum number of new tokens to generate.

    Yields:
        str: The response generated so far.
    """
    conversation = []
    for user, assistant in history:
        conversation.extend([{"from": "human", "value": user}, {"from": "gpt", "value": assistant}])
    conversation.append({"from": "human", "value": message})

    input_ids = tokenizer.apply_chat_template(
        conversation,
        tokenize=True,
        add_generation_prompt=True,  # Must add for generation
        return_tensors="pt",
    ).to(model.device)
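
    # For reference, the Llama-3 chat format this produces looks roughly like:
    #   <|begin_of_text|><|start_header_id|>user<|end_header_id|>\n\n{message}<|eot_id|>
    #   <|start_header_id|>assistant<|end_header_id|>\n\n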
    # Stream tokens from a background thread so partial text can be yielded as it is generated
    streamer = TextIteratorStreamer(tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True)

    generate_kwargs = dict(
        input_ids=input_ids,
        streamer=streamer,
        max_new_tokens=max_new_tokens,
        do_sample=True,
        temperature=temperature,
        eos_token_id=terminators,
    )
    if temperature == 0:
        generate_kwargs['do_sample'] = False  # Fall back to greedy decoding when temperature is zero

    t = Thread(target=model.generate, kwargs=generate_kwargs)
    t.start()

    outputs = []
    for text in streamer:
        outputs.append(text)
        yield "".join(outputs)
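
# Illustrative only: the generator can also be exercised outside Gradio, e.g.
#   for partial in chat_llama3_8b("Hello!", history=[], temperature=0.7, max_new_tokens=64):
#       print(partial)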

# Gradio block
chatbot = gr.Chatbot(height=450, placeholder=PLACEHOLDER, label='Gradio ChatInterface')

with gr.Blocks(fill_height=True, css=css) as demo:
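    # The css above targets an element with id "duplicate-button"; a hedged sketch of the
    # matching widget, assuming the Space is meant to offer duplication (left disabled here):
    # gr.DuplicateButton(value="Duplicate Space", elem_id="duplicate-button")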
    gr.ChatInterface(
        fn=chat_llama3_8b,
        chatbot=chatbot,
        fill_height=True,
        additional_inputs_accordion=gr.Accordion(label="⚙️ Parameters", open=False, render=False),
        additional_inputs=[
            gr.Slider(minimum=0,
                      maximum=1,
                      step=0.1,
                      value=0.95,
                      label="Temperature",
                      render=False),
            gr.Slider(minimum=128,
                      maximum=4096,
                      step=1,
                      value=512,
                      label="Max new tokens",
                      render=False),
        ],
        examples=[
            ['How can I file for a student loan case?']
        ],
        cache_examples=False,
    )

if __name__ == "__main__":
    demo.launch(debug=True)