import gradio as gr
from huggingface_hub import InferenceClient

# Initialize the client with the fine-tuned model
client = InferenceClient("deepseek-ai/DeepSeek-R1")  # Update if using another model

# Model's token limit
MODEL_TOKEN_LIMIT = 16384

# Function to validate inputs
def validate_inputs(max_tokens, temperature, top_p, input_tokens):
    if max_tokens + input_tokens > MODEL_TOKEN_LIMIT:
        raise ValueError(f"Max tokens + input tokens must not exceed {MODEL_TOKEN_LIMIT}. Adjust the max tokens.")
    if not (0.1 <= temperature <= 4.0):
        raise ValueError("Temperature must be between 0.1 and 4.0.")
    if not (0.1 <= top_p <= 1.0):
        raise ValueError("Top-p must be between 0.1 and 1.0.")

# Function to calculate input token count (basic approximation)
def count_tokens(messages):
    return sum(len(m["content"].split()) for m in messages)
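
# Optional, more accurate alternative (a sketch, not wired into the app): the word-split
# count above only approximates the prompt size. If the `transformers` package is
# installed, the model's own tokenizer gives a real token count. Using the same model id
# as the InferenceClient above is an assumption about which tokenizer applies.
#
# from transformers import AutoTokenizer
#
# _tokenizer = AutoTokenizer.from_pretrained("deepseek-ai/DeepSeek-R1")
#
# def count_tokens_exact(messages):
#     # Join message contents and count tokenizer tokens instead of whitespace-split words
#     text = "\n".join(m["content"] for m in messages)
#     return len(_tokenizer(text)["input_ids"])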

# Response generation
def respond(message, history, system_message, max_tokens, temperature, top_p):
    # Prepare messages for the model
    messages = [{"role": "system", "content": system_message}]
    for val in history:
        if val[0]:  # User's message
            messages.append({"role": "user", "content": val[0]})
        if val[1]:  # Assistant's response
            messages.append({"role": "assistant", "content": val[1]})
    messages.append({"role": "user", "content": message})

    # Calculate input token count
    input_tokens = count_tokens(messages)
    max_allowed_tokens = MODEL_TOKEN_LIMIT - input_tokens

    # Clamp max_tokens so the prompt plus the completion stays within the model's limit
    if max_tokens > max_allowed_tokens:
        max_tokens = max_allowed_tokens
    if max_tokens < 1:
        raise ValueError(
            f"The conversation already uses {input_tokens} tokens; "
            f"no room is left for a response within the {MODEL_TOKEN_LIMIT}-token limit."
        )

    validate_inputs(max_tokens, temperature, top_p, input_tokens)
    
    response = ""
    # Stream the completion, accumulating the partial text as chunks arrive
    for chunk in client.chat_completion(
        messages,
        max_tokens=max_tokens,
        stream=True,
        temperature=temperature,
        top_p=top_p,
    ):
        token = chunk.choices[0].delta.content or ""  # delta.content may be None on some chunks
        response += token
        yield response

# Updated system message
system_message = """
You are an advanced AI assistant specialized in coding tasks. 
- You deliver precise, error-free code in multiple programming languages.
- Analyze queries for logical accuracy and provide optimized solutions.
- Ensure clarity, brevity, and adherence to programming standards.

Guidelines:
1. Prioritize accurate, functional code.
2. Provide explanations only when necessary for understanding.
3. Handle tasks ethically, respecting user intent and legal constraints.

Thank you for using this system. Please proceed with your query.
"""

# Gradio Interface
demo = gr.ChatInterface(
    respond,
    additional_inputs=[
        gr.Textbox(value=system_message, label="System message", lines=10),
        gr.Slider(minimum=1, maximum=16384, value=1000, step=1, label="Max new tokens"),  # Default of 1000 new tokens
        gr.Slider(minimum=0.1, maximum=4.0, value=0.7, step=0.1, label="Temperature"),
        gr.Slider(minimum=0.1, maximum=1.0, value=0.95, step=0.05, label="Top-p (nucleus sampling)"),
    ],
)

if __name__ == "__main__":
    demo.launch()