import gradio as gr
import os
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
from threading import Thread
from unsloth.chat_templates import get_chat_template
from unsloth import FastLanguageModel
import torch


PLACEHOLDER = """
<div style="padding: 30px; text-align: center; display: flex; flex-direction: column; align-items: center;">

</div>
"""

css = """
h1 {
  text-align: center;
  display: block;
}
#duplicate-button {
  margin: auto;
  color: white;
  background: #1565c0;
  border-radius: 100vh;
}
"""

max_seq_length = 2048  # Choose any length; Unsloth handles RoPE scaling internally
dtype = None  # None for auto detection; float16 for Tesla T4/V100, bfloat16 for Ampere+
load_in_4bit = True  # Use 4-bit quantization to reduce memory usage; can be False

model, tokenizer = FastLanguageModel.from_pretrained(
    model_name="umair894/llama3",
    max_seq_length=max_seq_length,
    dtype=dtype,
    load_in_4bit=load_in_4bit,
)
FastLanguageModel.for_inference(model)  # Enable Unsloth's optimized inference mode

# Apply chat template to the tokenizer
tokenizer = get_chat_template(
    tokenizer,
    chat_template="llama-3",  # Supports zephyr, chatml, mistral, llama, alpaca, vicuna, vicuna_old, unsloth
    mapping={"role": "from", "content": "value", "user": "human", "assistant": "gpt"},  # ShareGPT style
    map_eos_token=True,  # Map the chat template's end-of-turn token to the tokenizer's EOS token
)

terminators = [
    tokenizer.eos_token_id,
    tokenizer.convert_tokens_to_ids("<|eot_id|>"),  # llama-3 end-of-turn token
]

# Check if terminators are None and provide a default value if needed
terminators = [token_id for token_id in terminators if token_id is not None]
if not terminators:
    terminators = [tokenizer.eos_token_id]  # Ensure there is a valid EOS token

def chat_llama3_8b(message: str,
                   history: list,
                   temperature: float,
                   max_new_tokens: int
                  ):
    """
    Stream a response from the llama3-8b model.
    Args:
        message (str): The input message.
        history (list): The conversation history used by ChatInterface.
        temperature (float): The sampling temperature for the response.
        max_new_tokens (int): The maximum number of new tokens to generate.
    Yields:
        str: The response generated so far.
    """

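    # Re-encode the ChatInterface history into ShareGPT-style "from"/"value" turns,
    # matching the mapping passed to get_chat_template above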
    conversation = []
    for user, assistant in history:
        conversation.extend([{"from": "human", "value": user}, {"from": "gpt", "value": assistant}])
    conversation.append({"from": "human", "value": message})

    input_ids = tokenizer.apply_chat_template(
        conversation,
        tokenize=True,
        add_generation_prompt=True,  # Must add for generation
        return_tensors="pt",
    ).to(model.device)

    streamer = TextIteratorStreamer(tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True)

    generate_kwargs = dict(
        input_ids=input_ids,
        streamer=streamer,
        max_new_tokens=max_new_tokens,
        do_sample=True,
        temperature=temperature,
        eos_token_id=terminators,
    )

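    # Sampling with temperature 0 is ill-defined, so fall back to greedy decoding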
    if temperature == 0:
        generate_kwargs['do_sample'] = False
        
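    # Run generation in a background thread; the streamer yields decoded text as tokens are produced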
    t = Thread(target=model.generate, kwargs=generate_kwargs)
    t.start()

    outputs = []
    for text in streamer:
        outputs.append(text)
        yield "".join(outputs)

# Gradio block
chatbot = gr.Chatbot(height=450, placeholder=PLACEHOLDER, label='Gradio ChatInterface')

with gr.Blocks(fill_height=True, css=css) as demo:
    
    gr.ChatInterface(
        fn=chat_llama3_8b,
        chatbot=chatbot,
        fill_height=True,
        additional_inputs_accordion=gr.Accordion(label="⚙️ Parameters", open=False, render=False),
        additional_inputs=[
            gr.Slider(minimum=0,
                      maximum=1, 
                      step=0.1,
                      value=0.95, 
                      label="Temperature", 
                      render=False),
            gr.Slider(minimum=128, 
                      maximum=4096,
                      step=1,
                      value=512, 
                      label="Max new tokens", 
                      render=False ),
        ],
        examples=[
            ['How can I file for a student loan case?']
        ],
        cache_examples=False,
    )
    
    
if __name__ == "__main__":
    demo.launch(debug=True)