Spaces:
Sleeping
Sleeping
File size: 2,749 Bytes
e55bd08 f60e921 e55bd08 f60e921 e55bd08 f60e921 131a07a e55bd08 131a07a f60e921 131a07a f60e921 131a07a f60e921 15d1015 f60e921 e55bd08 f60e921 e55bd08 f60e921 e55bd08 f60e921 e55bd08 f60e921 e55bd08 0474700 e55bd08 e7e3b25 e55bd08 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 |
import gradio as gr
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
import spaces
from threading import Thread
from typing import Iterator
# --- Model and tokenizer (loaded once, at import time) ---
model_name = "Magpie-Align/MagpieLM-4B-Chat-v0.1"
device = "cuda"  # where the model weights are placed after loading

tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
# torch_dtype="auto" asks transformers to pick the dtype from the checkpoint
# rather than its float32 default.
model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype="auto")
model.to(device)

# Prompt-length cap in tokens; longer prompts are cut down to the most recent
# MAX_INPUT_TOKEN_LENGTH tokens before generation. You may need to adjust this.
MAX_INPUT_TOKEN_LENGTH = 4096
def _build_conversation(
    message: str,
    chat_history: list,
    system_prompt: str,
) -> list[dict[str, str]]:
    """Assemble the role/content message list passed to the chat template.

    Accepts Gradio chat history in either the "tuples" format
    (``[(user, assistant), ...]``) or the "messages" format
    (``[{"role": ..., "content": ...}, ...]``). Turns whose content is not a
    string (e.g. ``None`` for a missing reply) are skipped, because chat
    templates expect text content.
    """
    conversation: list[dict[str, str]] = []
    if system_prompt:
        conversation.append({"role": "system", "content": system_prompt})
    for turn in chat_history:
        if isinstance(turn, dict):
            # "messages" format: the entry is already role/content shaped.
            content = turn.get("content")
            if isinstance(content, str):
                conversation.append({"role": turn["role"], "content": content})
            continue
        # "tuples" format: one (user, assistant) pair per exchange.
        user, assistant = turn
        if isinstance(user, str):
            conversation.append({"role": "user", "content": user})
        if isinstance(assistant, str):
            conversation.append({"role": "assistant", "content": assistant})
    conversation.append({"role": "user", "content": message})
    return conversation


@spaces.GPU
def generate(
    message: str,
    chat_history: list[tuple[str, str]],
    system_prompt: str,
    max_new_tokens: int = 1024,
    temperature: float = 0.6,
    top_p: float = 0.9,
    top_k: int = 50,
    repetition_penalty: float = 1.2,
) -> Iterator[str]:
    """Stream the model's reply to *message*, yielding the cumulative text.

    Args:
        message: The new user message.
        chat_history: Prior turns, as (user, assistant) tuples or as Gradio
            "messages"-style role/content dicts.
        system_prompt: Optional system message; omitted when empty.
        max_new_tokens: Upper bound on generated tokens.
        temperature: Sampling temperature.
        top_p: Nucleus-sampling probability mass.
        top_k: Number of highest-probability tokens considered per step.
        repetition_penalty: Penalty applied to already-generated tokens.

    Yields:
        The full reply decoded so far, growing with each streamed chunk.
    """
    conversation = _build_conversation(message, chat_history, system_prompt)
    # add_generation_prompt=True appends the assistant-turn header, so the
    # model starts a reply instead of continuing the user's turn.
    input_ids = tokenizer.apply_chat_template(
        conversation, add_generation_prompt=True, return_tensors="pt"
    )
    if input_ids.shape[1] > MAX_INPUT_TOKEN_LENGTH:
        # Keep the most recent tokens; the oldest part of the prompt is dropped.
        input_ids = input_ids[:, -MAX_INPUT_TOKEN_LENGTH:]
        gr.Warning(f"Trimmed input from conversation as it was longer than {MAX_INPUT_TOKEN_LENGTH} tokens.")
    input_ids = input_ids.to(model.device)

    streamer = TextIteratorStreamer(tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True)
    generate_kwargs = dict(
        input_ids=input_ids,
        streamer=streamer,
        max_new_tokens=max_new_tokens,
        do_sample=True,
        top_p=top_p,
        top_k=top_k,
        temperature=temperature,
        num_beams=1,
        repetition_penalty=repetition_penalty,
    )
    # model.generate blocks until it finishes, so run it on a worker thread
    # and consume decoded text from the streamer as it arrives.
    worker = Thread(target=model.generate, kwargs=generate_kwargs)
    worker.start()
    outputs = []
    for text in streamer:
        outputs.append(text)
        yield "".join(outputs)
    # The streamer is exhausted once generation has ended; reap the thread.
    worker.join()
# Chat UI. Gradio passes the values of `additional_inputs` to generate()
# positionally after (message, chat_history), in the order listed here.
demo = gr.ChatInterface(
    fn=generate,
    additional_inputs=[
        gr.Textbox(label="System prompt", value="You are Magpie, a friendly Chatbot."),
        gr.Slider(label="Max new tokens", minimum=1, maximum=2048, step=1, value=1024),
        gr.Slider(label="Temperature", minimum=0.1, maximum=4.0, step=0.1, value=0.6),
        gr.Slider(label="Top-p", minimum=0.1, maximum=1.0, step=0.05, value=0.9),
        gr.Slider(label="Top-k", minimum=1, maximum=100, step=1, value=50),
        gr.Slider(label="Repetition Penalty", minimum=0.5, maximum=2.0, step=0.1, value=1.2),
    ],
)
if __name__ == "__main__":
    # Enable Gradio's request queue, then start the server; share=True asks
    # Gradio to create a public share link.
    demo.queue().launch(share=True)
|