File size: 2,890 Bytes
86d9217
 
 
cda6239
a7a95f3
86d9217
cda6239
 
 
86d9217
401be5d
 
 
 
 
 
 
 
 
 
 
 
 
86d9217
 
401be5d
 
 
86d9217
cda6239
 
 
 
401be5d
86d9217
401be5d
86d9217
401be5d
 
86d9217
 
 
401be5d
 
 
 
 
 
f154bf1
 
86d9217
 
 
 
f154bf1
86d9217
f154bf1
 
 
 
86d9217
1c91ea3
cda6239
115746c
cda6239
 
1c91ea3
 
cda6239
1c91ea3
 
cda6239
401be5d
 
 
cda6239
1c91ea3
 
cda6239
 
401be5d
cda6239
 
86d9217
f154bf1
86d9217
401be5d
86d9217
401be5d
 
86d9217
 
 
 
f154bf1
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
import gradio as gr
from huggingface_hub import InferenceClient

# Module-level inference client, initially bound to the default model.
client = InferenceClient(model="mistralai/Mistral-7B-Instruct-v0.3")

# Build an inference client for whichever model was selected
def switch_client(model_name: str) -> "InferenceClient":
    """Return a newly constructed ``InferenceClient`` for ``model_name``."""
    fresh_client = InferenceClient(model_name)
    return fresh_client

# Generation presets for the Mistral model, one row per preset:
# (preset name, max_tokens, temperature, top_p)
_MISTRAL_PRESET_ROWS = (
    ("Fast", 256, 1.0, 0.9),
    ("Normal", 512, 0.7, 0.95),
    ("Quality", 1024, 0.5, 0.90),
    ("Unreal Performance", 2048, 0.6, 0.75),
)

# Presets keyed first by model id, then by preset name
presets = {
    "mistralai/Mistral-7B-Instruct-v0.3": {
        preset_label: {
            "max_tokens": token_cap,
            "temperature": temp,
            "top_p": nucleus,
        }
        for preset_label, token_cap, temp, nucleus in _MISTRAL_PRESET_ROWS
    }
}

# Fixed system message, sent as the first message of every request
SYSTEM_MESSAGE = "Lake 1 Base"

def _history_to_messages(history):
    """Convert Gradio chat history into role/content message dicts.

    Two history layouts are accepted:

    * "messages" style: a list of dicts carrying ``role`` and ``content``
      keys, which are copied through as-is;
    * pair style: a list of ``[user_text, assistant_text]`` pairs, where
      either side may be ``None``.

    Entries matching neither layout are skipped.
    """
    converted = []
    for turn in history or []:
        if isinstance(turn, dict):
            if "role" in turn and "content" in turn:
                converted.append({"role": turn["role"], "content": turn["content"]})
        elif isinstance(turn, (list, tuple)) and len(turn) == 2:
            user_text, assistant_text = turn
            # Only plain-text sides are forwarded; non-string entries
            # (presumably media attachments) are left out of the prompt.
            if isinstance(user_text, str):
                converted.append({"role": "user", "content": user_text})
            if isinstance(assistant_text, str):
                converted.append({"role": "assistant", "content": assistant_text})
    return converted


def respond(
    message,
    history: list,
    model_name,
    preset_name
):
    """Send the conversation to the selected model and return its reply.

    Args:
        message: The new user message.
        history: Previous turns, either as role/content dicts or as
            ``[user, assistant]`` pairs; both layouts are preserved as
            conversation context.
        model_name: Model id, used both to build the client and as the
            top-level key into ``presets``.
        preset_name: Name of the preset under ``presets[model_name]``.

    Returns:
        The content of the first choice in the chat completion response.

    Raises:
        KeyError: If ``model_name`` or ``preset_name`` is not in ``presets``.
    """
    # Switch client based on model selection (rebinds the module-level client)
    global client
    client = switch_client(model_name)

    messages = [{"role": "system", "content": SYSTEM_MESSAGE}]
    messages.extend(_history_to_messages(history))
    messages.append({"role": "user", "content": message})

    # Get the preset settings
    preset = presets[model_name][preset_name]

    # Get the response from the model
    response = client.chat_completion(
        messages,
        max_tokens=preset["max_tokens"],
        temperature=preset["temperature"],
        top_p=preset["top_p"],
    )

    # Extract the content from the response
    return response.choices[0].message['content']

# (model id, pseudonym) pairs; the pseudonym is the name shown to users
model_choices = [
    ("mistralai/Mistral-7B-Instruct-v0.3", "Lake 1 Base"),
]

# Pseudonyms only, in the same order, used as the dropdown choices
pseudonyms = [pseudonym for _model_id, pseudonym in model_choices]

# Function to handle model selection and pseudonyms
def respond_with_pseudonym(
    message,
    history: list,
    selected_pseudonym,
    selected_preset
):
    """Resolve a model pseudonym to its model id and delegate to ``respond``.

    Args:
        message: The new user message.
        history: Previous conversation turns, forwarded unchanged.
        selected_pseudonym: Display name of the model, as listed in
            ``model_choices``.
        selected_preset: Preset name, forwarded unchanged to ``respond``.

    Returns:
        Whatever ``respond`` returns for the resolved model.

    Raises:
        ValueError: If ``selected_pseudonym`` matches no entry in
            ``model_choices``.
    """
    # Find the actual model name from the pseudonym; a default of None avoids
    # leaking a bare, message-less StopIteration when nothing matches.
    model_name = next(
        (model_id for model_id, pseudonym in model_choices if pseudonym == selected_pseudonym),
        None,
    )
    if model_name is None:
        known = ", ".join(repr(pseudonym) for _, pseudonym in model_choices)
        raise ValueError(
            f"Unknown model pseudonym {selected_pseudonym!r}; expected one of: {known}"
        )

    # Call the existing respond function
    return respond(message, history, model_name, selected_preset)

# Gradio Chat Interface
# NOTE: the additional inputs are passed to ``fn`` positionally after
# (message, history), so their order must match the
# (selected_pseudonym, selected_preset) parameters of respond_with_pseudonym.
demo = gr.ChatInterface(
    fn=respond_with_pseudonym,
    additional_inputs=[
        # Pseudonym selection dropdown -> selected_pseudonym
        gr.Dropdown(choices=pseudonyms, label="Select Model", value=pseudonyms[0]),
        # Preset selection dropdown -> selected_preset
        gr.Dropdown(
            choices=list(presets["mistralai/Mistral-7B-Instruct-v0.3"]),
            label="Select Preset",
            value="Fast",
        ),
    ],
)

if __name__ == "__main__":
    demo.launch()