import os

# Force CPU execution: hide all CUDA devices before torch is initialized
os.environ["CUDA_VISIBLE_DEVICES"] = ""

import gradio as gr

# from unsloth import FastLanguageModel
from peft import AutoPeftModelForCausalLM
from transformers import TextIteratorStreamer, AutoTokenizer
from threading import Thread

"""
For more information on `huggingface_hub` Inference API support, please check the docs: https://huggingface.co/docs/huggingface_hub/v0.22.2/en/guides/inference
"""

# client = InferenceClient()
class MyModel:
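    """Lazily loads the selected model and streams chat completions for Gradio."""
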
    def __init__(self):
        self.client = None
        self.current_model = ""
        self.tokenizer = None

    def respond(
        self,
        message,
        history: list[tuple[str, str]],
        model,
        system_message,
        max_tokens,
        temperature,
        min_p,
    ):
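        """Yield the growing assistant reply, (re)loading the model when the selection changes."""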
        # (Re)load only when the selected model changes or nothing is loaded yet
        if model != self.current_model or self.client is None:
            # client, tokenizer = FastLanguageModel.from_pretrained(
            #     model_name = model,
            #     max_seq_length = 2048,
            #     dtype = None,
            #     load_in_4bit = True,
            # )
            # FastLanguageModel.for_inference(client) # Enable native 2x faster inference
            tokenizer = AutoTokenizer.from_pretrained(model)
            # 4-bit quantized load (requires the `bitsandbytes` package)
            client = AutoPeftModelForCausalLM.from_pretrained(model, load_in_4bit=True)

            self.client = client
            self.tokenizer = tokenizer
            self.current_model = model
        
        # The streamer yields decoded text chunks; skip_prompt drops the echoed input
        text_streamer = TextIteratorStreamer(self.tokenizer, skip_prompt=True)

        # Build an OpenAI-style messages list: system prompt, then alternating turns
        messages = [{"role": "system", "content": system_message}]

        for val in history:
            if val[0]:
                messages.append({"role": "user", "content": val[0]})
            if val[1]:
                messages.append({"role": "assistant", "content": val[1]})

        messages.append({"role": "user", "content": message})

        inputs = self.tokenizer.apply_chat_template(
            messages,
            tokenize=True,
            add_generation_prompt=True,  # must be set for generation
            return_tensors="pt",
        )
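        # `inputs` is a [1, seq_len] tensor of token ids for the full conversation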
        
        generation_kwargs = dict(input_ids=inputs, streamer=text_streamer, max_new_tokens=max_tokens,
                                 use_cache=True, temperature=temperature, min_p=min_p)
        # Generate in a background thread so tokens can be streamed as they arrive
        thread = Thread(target=self.client.generate, kwargs=generation_kwargs)
        thread.start()

        response = ""

        for new_text in text_streamer:
            response += new_text
            yield response.strip("<|eot_id|>")


"""
For information on how to customize the ChatInterface, peruse the gradio docs: https://www.gradio.app/docs/chatinterface
"""
my_model = MyModel()
model_choices = [
    "lab2-as/lora_model",
    "lab2-as/lora_model_no_quant",
]
demo = gr.ChatInterface(
    my_model.respond,
    additional_inputs=[
        gr.Dropdown(choices=model_choices, value=model_choices[0], label="Select Model"),
        gr.Textbox(value="You are a friendly Chatbot.", label="System message"),
        gr.Slider(minimum=1, maximum=2048, value=128, step=1, label="Max new tokens"),
        gr.Slider(minimum=0.1, maximum=4.0, value=0.7, step=0.1, label="Temperature"),
        gr.Slider(
            minimum=0.1,
            maximum=1.0,
            value=0.95,
            step=0.05,
            label="Min-p (nucleus sampling)",
        ),
    ],
)


if __name__ == "__main__":
    demo.launch()
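
# Rough dependency sketch for running this script (versions unpinned; assumes a
# bitsandbytes-compatible environment for the 4-bit load):
#   pip install gradio transformers peft accelerate bitsandbytes torch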