File size: 2,749 Bytes
e55bd08
 
 
 
f60e921
e55bd08
 
 
 
 
 
 
 
 
 
 
 
 
 
f60e921
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
e55bd08
f60e921
131a07a
 
 
 
 
 
 
e55bd08
131a07a
f60e921
131a07a
 
f60e921
131a07a
f60e921
15d1015
 
f60e921
 
e55bd08
f60e921
 
 
 
e55bd08
 
f60e921
e55bd08
f60e921
 
e55bd08
f60e921
 
 
e55bd08
 
0474700
e55bd08
e7e3b25
e55bd08
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
import gradio as gr
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
import spaces
from threading import Thread
from typing import Iterator

# Hub repository id of the chat checkpoint served by this demo.
model_name = "Magpie-Align/MagpieLM-4B-Chat-v0.1"

# Device the weights are moved onto once, at import time.
device = "cuda"

tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
# torch_dtype="auto" keeps whatever dtype the checkpoint's config specifies.
model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype="auto").to(device)

# Upper bound on prompt length in tokens; longer prompts are left-trimmed in
# `generate`. May need tuning for the model's context window.
MAX_INPUT_TOKEN_LENGTH = 4096

def _build_conversation(
    message: str,
    chat_history: list,
    system_prompt: str,
) -> list[dict[str, str]]:
    """Assemble the role/content message list for one generation request.

    Accepts Gradio history in either "tuples" format (``(user, assistant)``
    pairs, where either side may be ``None``) or "messages" format (dicts
    carrying ``role`` and ``content`` keys).

    Args:
        message: The new user message.
        chat_history: Previous turns, oldest first.
        system_prompt: Optional system instruction; omitted when empty.

    Returns:
        A list of ``{"role": ..., "content": ...}`` dicts that ends with the
        new user message.
    """
    conversation = []
    if system_prompt:
        conversation.append({"role": "system", "content": system_prompt})
    for turn in chat_history:
        if isinstance(turn, dict):
            # Messages-format history already carries explicit roles.
            conversation.append({"role": turn["role"], "content": turn["content"]})
            continue
        user, assistant = turn
        # A pair can have one side missing (None); skip it rather than
        # sending None content to the chat template.
        if user is not None:
            conversation.append({"role": "user", "content": user})
        if assistant is not None:
            conversation.append({"role": "assistant", "content": assistant})
    conversation.append({"role": "user", "content": message})
    return conversation


@spaces.GPU
def generate(
    message: str,
    chat_history: list,
    system_prompt: str,
    max_new_tokens: int = 1024,
    temperature: float = 0.6,
    top_p: float = 0.9,
    top_k: int = 50,
    repetition_penalty: float = 1.2,
) -> Iterator[str]:
    """Stream the model's reply to ``message`` given the prior chat history.

    Args:
        message: The new user message.
        chat_history: Previous turns, as ``(user, assistant)`` pairs or as
            role/content dicts.
        system_prompt: Optional system instruction; omitted when empty.
        max_new_tokens: Maximum number of tokens to generate.
        temperature: Sampling temperature.
        top_p: Nucleus-sampling probability mass.
        top_k: Number of highest-probability tokens considered per step.
        repetition_penalty: Penalty applied to already-generated tokens.

    Yields:
        The cumulative reply text so far, growing with every streamed chunk.

    Raises:
        queue.Empty: raised by the streamer if no new text arrives from the
            generation thread within its 10 s timeout.
    """
    conversation = _build_conversation(message, chat_history, system_prompt)

    # add_generation_prompt=True appends the template's assistant-turn prefix
    # so the model answers as the assistant instead of continuing the user
    # turn. return_dict=True also returns the attention mask for generate().
    inputs = tokenizer.apply_chat_template(
        conversation,
        add_generation_prompt=True,
        return_dict=True,
        return_tensors="pt",
    )
    input_ids = inputs["input_ids"]
    attention_mask = inputs["attention_mask"]
    if input_ids.shape[1] > MAX_INPUT_TOKEN_LENGTH:
        # Keep only the most recent tokens; this may drop the system prompt.
        input_ids = input_ids[:, -MAX_INPUT_TOKEN_LENGTH:]
        attention_mask = attention_mask[:, -MAX_INPUT_TOKEN_LENGTH:]
        gr.Warning(f"Trimmed input from conversation as it was longer than {MAX_INPUT_TOKEN_LENGTH} tokens.")
    input_ids = input_ids.to(model.device)
    attention_mask = attention_mask.to(model.device)

    streamer = TextIteratorStreamer(tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True)
    generate_kwargs = dict(
        input_ids=input_ids,
        attention_mask=attention_mask,
        streamer=streamer,
        max_new_tokens=max_new_tokens,
        do_sample=True,
        top_p=top_p,
        top_k=top_k,
        temperature=temperature,
        num_beams=1,
        repetition_penalty=repetition_penalty,
    )
    # model.generate runs in a worker thread and pushes decoded text into the
    # streamer, which this generator drains.
    worker = Thread(target=model.generate, kwargs=generate_kwargs)
    worker.start()

    reply = ""
    for chunk in streamer:
        reply += chunk
        yield reply
    # The streamer is exhausted only after generate() finishes, so this is quick.
    worker.join()

# Chat UI. Gradio passes the values of `additional_inputs` to `generate`, in
# this order, after (message, chat_history), so the list must line up with
# generate's parameters: system_prompt, max_new_tokens, temperature, top_p,
# top_k, repetition_penalty.
demo = gr.ChatInterface(
    fn=generate,
    additional_inputs=[
        gr.Textbox(label="System prompt", value="You are Magpie, a friendly Chatbot."),
        gr.Slider(label="Max new tokens", minimum=1, maximum=2048, step=1, value=1024),
        gr.Slider(label="Temperature", minimum=0.1, maximum=4.0, step=0.1, value=0.6),
        gr.Slider(label="Top-p", minimum=0.1, maximum=1.0, step=0.05, value=0.9),
        gr.Slider(label="Top-k", minimum=1, maximum=100, step=1, value=50),
        gr.Slider(label="Repetition Penalty", minimum=0.5, maximum=2.0, step=0.1, value=1.2),
    ],
)

if __name__ == "__main__":
    # Enable the request queue, then start the server; share=True asks
    # Gradio for a public share link.
    demo.queue().launch(share=True)