File size: 3,403 Bytes
92c50d2
 
24c93e5
d6ed3af
f16d5ff
10ae4aa
 
d6ed3af
 
 
 
 
92c50d2
 
 
4deb2e9
91fb6a1
 
46e0200
92c50d2
7958865
02c7617
93116a8
46e0200
27dde00
 
 
02c7617
 
46e0200
02c7617
46e0200
 
 
5dfffe1
0a93e35
 
5dfffe1
 
 
 
 
93116a8
 
 
3f4b192
 
02c7617
 
 
46e0200
02c7617
 
3f4b192
 
 
 
92c50d2
 
 
 
 
 
 
 
 
 
 
 
02c7617
92c50d2
 
 
 
 
 
 
 
 
91fb6a1
 
 
92c50d2
74cc87b
49a62d8
92c50d2
27dde00
9f95576
92c50d2
 
46e0200
 
92c50d2
 
3d50071
92c50d2
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
import gradio as gr
from huggingface_hub import InferenceClient
import spaces
import torch
import os

# Module-level placeholder. Not read by the code in this file; the model is
# selected per request through the `model` argument of `respond`.
model = ""

# Startup diagnostics. `torch.cuda.current_device()` raises when no CUDA device
# is visible, so only query the device name when CUDA is actually available.
_cuda_available = torch.cuda.is_available()
print(f"Is CUDA available: {_cuda_available}")
if _cuda_available:
    print(f"CUDA device: {torch.cuda.get_device_name(torch.cuda.current_device())}")
else:
    print("CUDA device: none")

# For more information on `huggingface_hub` Inference API support, see:
# https://huggingface.co/docs/huggingface_hub/v0.22.2/en/guides/inference


# Display names offered in the UI model dropdown -> Hugging Face Hub repo ids.
# NOTE: keys must match the dropdown choices exactly, including the existing
# "Zephr-7b-beta" spelling.
_MODEL_REPOS = {
    "DeepSeek-R1-Distill-Qwen-1.5B": "deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B",
    "DeepSeek-R1-Distill-Qwen-32B": "deepseek-ai/DeepSeek-R1-Distill-Qwen-32B",
    "Llama3-8b-Instruct": "meta-llama/Meta-Llama-3-8B-Instruct",
    "Llama3.1-8b-Instruct": "meta-llama/Llama-3.1-8B-Instruct",
    "Llama2-13b-chat": "meta-llama/Llama-2-13b-chat-hf",
    "Gemma-2-2b": "google/gemma-2-2b-it",
    "Mixtral-8x7B-Instruct": "mistralai/Mixtral-8x7B-Instruct-v0.1",
    "Zephr-7b-beta": "HuggingFaceH4/zephyr-7b-beta",
}


def choose_model(model_name):
    """Return the Hugging Face Hub repo id for a UI model display name.

    Args:
        model_name: One of the display names listed in ``_MODEL_REPOS``.

    Returns:
        The Hub repo id, e.g. ``"google/gemma-2-2b-it"`` for ``"Gemma-2-2b"``.

    Raises:
        ValueError: If ``model_name`` is not a known display name (including
            ``None``). Previously this surfaced as an opaque
            ``UnboundLocalError``.
    """
    try:
        return _MODEL_REPOS[model_name]
    except KeyError:
        known = ", ".join(_MODEL_REPOS)
        raise ValueError(f"Unknown model {model_name!r}; expected one of: {known}") from None
    

def _build_messages(message, history, system_message):
    """Build a chat-completion message list from tuple-style Gradio history.

    The system prompt comes first, then each (user, assistant) pair from
    ``history`` (empty or ``None`` halves are skipped), then the new user
    ``message``.
    """
    messages = [{"role": "system", "content": system_message}]
    for turn in history:
        user_turn, assistant_turn = turn[0], turn[1]
        if user_turn:
            messages.append({"role": "user", "content": user_turn})
        if assistant_turn:
            messages.append({"role": "assistant", "content": assistant_turn})
    messages.append({"role": "user", "content": message})
    return messages


# NOTE(review): generation runs remotely through `InferenceClient`, so this
# function does no local GPU work; confirm whether the ZeroGPU allocation (and
# its 1-second `duration`) is needed at all for streamed replies.
@spaces.GPU(duration=1)
def respond(message, history: list[tuple[str, str]], model, system_message, max_tokens, temperature, top_p):
    """Stream a chat reply from the selected model via the HF Inference API.

    Args:
        message: The new user message.
        history: Prior (user, assistant) turns supplied by ``gr.ChatInterface``.
        model: Display name chosen in the model dropdown; mapped to a Hub repo
            id by ``choose_model``.
        system_message: System prompt placed first in the conversation.
        max_tokens: Maximum number of new tokens to generate.
        temperature: Sampling temperature forwarded to ``chat_completion``.
        top_p: Nucleus-sampling threshold forwarded to ``chat_completion``.

    Yields:
        The accumulated response text after each streamed chunk that carries
        new content.
    """
    print(model)
    model_name = choose_model(model)

    # API token is read from the `deepseekv2` environment variable.
    client = InferenceClient(model_name, token=os.getenv('deepseekv2'))
    messages = _build_messages(message, history, system_message)

    response = ""
    # `chunk` (not `message`) so the loop does not shadow the user-message parameter.
    for chunk in client.chat_completion(messages, max_tokens=max_tokens, stream=True, temperature=temperature, top_p=top_p):
        # `delta.content` is Optional in huggingface_hub's stream output types,
        # and a chunk may have no choices; skip those instead of crashing on
        # `str += None` / IndexError.
        if not chunk.choices:
            continue
        token = chunk.choices[0].delta.content
        if not token:
            continue
        response += token
        yield response


"""
For information on how to customize the ChatInterface, see the Gradio docs: https://www.gradio.app/docs/chatinterface
"""


    
# Chat UI. The additional inputs are passed to `respond` positionally after
# (message, history): model, system_message, max_tokens, temperature, top_p.
demo = gr.ChatInterface(
    respond,
    additional_inputs=[
        # Choice strings must match the names understood by `choose_model`.
        # An explicit default keeps `respond` from receiving `None` as the model
        # on Gradio versions that do not auto-select the first choice.
        gr.Dropdown(
            [
                "DeepSeek-R1-Distill-Qwen-1.5B",
                "DeepSeek-R1-Distill-Qwen-32B",
                "Gemma-2-2b",
                "Llama2-13b-chat",
                "Llama3-8b-Instruct",
                "Llama3.1-8b-Instruct",
                "Mixtral-8x7B-Instruct",
                "Zephr-7b-beta",
            ],
            value="DeepSeek-R1-Distill-Qwen-1.5B",
            label="Select Model",
        ),
        gr.Textbox(value="You are a friendly and helpful Chatbot, be concise and straight to the point, avoid excessive reasoning.", label="System message"),
        gr.Slider(minimum=1, maximum=2048, value=512, step=1, label="Max new tokens"),
        gr.Slider(minimum=0.1, maximum=4.0, value=0.7, step=0.1, label="Temperature"),
        gr.Slider(minimum=0.1, maximum=1.0, value=0.95, step=0.05, label="Top-p (nucleus sampling)"),
    ],
)


if __name__ == "__main__":
    demo.launch()