import gradio as gr
import os
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
from threading import Thread
from unsloth.chat_templates import get_chat_template
from unsloth import FastLanguageModel
import torch
from typing import Iterator
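# Gradio chat app that streams responses from a 4-bit quantized llama-3
# fine-tune loaded through Unsloth's FastLanguageModel.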
PLACEHOLDER = """
<div style="padding: 30px; text-align: center; display: flex; flex-direction: column; align-items: center;">
</div>
"""
css = """
h1 {
text-align: center;
display: block;
}
#duplicate-button {
margin: auto;
color: white;
background: #1565c0;
border-radius: 100vh;
}
"""
max_seq_length = 2048  # Maximum context length; Unsloth handles RoPE scaling internally
dtype = None  # None for auto-detection; torch.float16 for Tesla T4/V100, torch.bfloat16 for Ampere+
load_in_4bit = True  # Use 4-bit quantization to reduce memory usage; can be False
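# Load the fine-tuned model and its tokenizer in one call; Unsloth applies
# the quantization and dtype settings above.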
model, tokenizer = FastLanguageModel.from_pretrained(
    model_name="umair894/llama3",
    max_seq_length=max_seq_length,
    dtype=dtype,
    load_in_4bit=load_in_4bit,
)
FastLanguageModel.for_inference(model)  # Enable Unsloth's faster native inference mode
# Apply chat template to the tokenizer
tokenizer = get_chat_template(
    tokenizer,
    chat_template="llama-3",  # Also supports zephyr, chatml, mistral, llama, alpaca, vicuna, vicuna_old, unsloth
    mapping={"role": "from", "content": "value", "user": "human", "assistant": "gpt"},  # ShareGPT-style keys
    map_eos_token=True,  # Map the template's end-of-turn token to the tokenizer's EOS token
)
terminators = [
    tokenizer.eos_token_id,
    tokenizer.convert_tokens_to_ids("<|eot_id|>"),  # llama-3 end-of-turn token
]
# Drop unresolved (None) token ids and fall back to EOS if nothing remains
terminators = [token_id for token_id in terminators if token_id is not None]
if not terminators:
    terminators = [tokenizer.eos_token_id]  # Ensure there is a valid EOS token
def chat_llama3_8b(message: str,
                   history: list,
                   temperature: float,
                   max_new_tokens: int,
                   ) -> Iterator[str]:
"""
Generate a streaming response using the llama3-8b model.
Args:
message (str): The input message.
history (list): The conversation history used by ChatInterface.
temperature (float): The temperature for generating the response.
max_new_tokens (int): The maximum number of new tokens to generate.
Returns:
str: The generated response.
"""
    conversation = []
    for user, assistant in history:
        conversation.extend([{"from": "human", "value": user}, {"from": "gpt", "value": assistant}])
    conversation.append({"from": "human", "value": message})

    input_ids = tokenizer.apply_chat_template(
        conversation,
        tokenize=True,
        add_generation_prompt=True,  # Must add for generation
        return_tensors="pt",
    ).to(model.device)
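    # Stream decoded tokens as they are produced so the UI can update
    # incrementally; the 10 s timeout keeps the consumer from blocking
    # forever if generation stalls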
    streamer = TextIteratorStreamer(tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True)

    generate_kwargs = dict(
        input_ids=input_ids,
        streamer=streamer,
        max_new_tokens=max_new_tokens,
        do_sample=True,
        temperature=temperature,
        eos_token_id=terminators,
    )
    # Sampling with temperature 0 is undefined, so switch to greedy decoding
    if temperature == 0:
        generate_kwargs['do_sample'] = False
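    # Run generation in a background thread so this function can consume the
    # streamer and yield partial output back to Gradio as it arrives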
    t = Thread(target=model.generate, kwargs=generate_kwargs)
    t.start()

    outputs = []
    for text in streamer:
        outputs.append(text)
        yield "".join(outputs)
# Gradio block
chatbot = gr.Chatbot(height=450, placeholder=PLACEHOLDER, label='Gradio ChatInterface')
with gr.Blocks(fill_height=True, css=css) as demo:
    gr.ChatInterface(
        fn=chat_llama3_8b,
        chatbot=chatbot,
        fill_height=True,
        additional_inputs_accordion=gr.Accordion(label="⚙️ Parameters", open=False, render=False),
        additional_inputs=[
            gr.Slider(minimum=0,
                      maximum=1,
                      step=0.1,
                      value=0.95,
                      label="Temperature",
                      render=False),
            gr.Slider(minimum=128,
                      maximum=4096,
                      step=1,
                      value=512,
                      label="Max new tokens",
                      render=False),
        ],
        examples=[
            ['How can I file for a student loan case?']
        ],
        cache_examples=False,
    )
if __name__ == "__main__":
demo.launch(debug=True) |
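# Run locally with `python app.py`; debug=True keeps launch() blocking and
# prints errors to the console (or the Space's logs).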