holytinz278 committed
Commit 4f38b0f · verified · 1 Parent(s): a42dad2

Update app.py

Files changed (1)
  1. app.py +21 -7
app.py CHANGED
@@ -4,19 +4,24 @@ from huggingface_hub import InferenceClient
 # Initialize the client with the fine-tuned model
 client = InferenceClient("Qwen/QwQ-32B-Preview")  # Update if using another model
 
+# Model's token limit
+MODEL_TOKEN_LIMIT = 16384
+
 # Function to validate inputs
-def validate_inputs(max_tokens, temperature, top_p):
-    if not (1 <= max_tokens <= 32768):
-        raise ValueError("Max tokens must be between 1 and 32768.")
+def validate_inputs(max_tokens, temperature, top_p, input_tokens):
+    if max_tokens + input_tokens > MODEL_TOKEN_LIMIT:
+        raise ValueError(f"Max tokens + input tokens must not exceed {MODEL_TOKEN_LIMIT}. Adjust the max tokens.")
     if not (0.1 <= temperature <= 4.0):
         raise ValueError("Temperature must be between 0.1 and 4.0.")
     if not (0.1 <= top_p <= 1.0):
         raise ValueError("Top-p must be between 0.1 and 1.0.")
 
+# Function to calculate input token count (basic approximation)
+def count_tokens(messages):
+    return sum(len(m["content"].split()) for m in messages)
+
 # Response generation
 def respond(message, history, system_message, max_tokens, temperature, top_p):
-    validate_inputs(max_tokens, temperature, top_p)
-
     # Prepare messages for the model
     messages = [{"role": "system", "content": system_message}]
     for val in history:
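The new count_tokens helper estimates length by whitespace-splitting, which is fast and dependency-free but typically undercounts against a subword vocabulary (the QwQ tokenizer often emits more than one token per word). A tighter bound would come from the model's own tokenizer, which counts exactly what the endpoint will see; a minimal sketch, assuming the transformers package is installed and the Qwen/QwQ-32B-Preview tokenizer can be downloaded:

    from transformers import AutoTokenizer

    # Load once at startup; needs Hub access on first run.
    tokenizer = AutoTokenizer.from_pretrained("Qwen/QwQ-32B-Preview")

    def count_tokens_exact(messages):
        # Render the chat template exactly as the model receives it,
        # special tokens included, and count the resulting ids.
        ids = tokenizer.apply_chat_template(messages, add_generation_prompt=True)
        return len(ids)

Note that the whitespace estimate errs low, so the max_allowed_tokens computed below can still overshoot the real window; the tokenizer-based count avoids that at the cost of a one-time model download.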
@@ -26,8 +31,17 @@ def respond(message, history, system_message, max_tokens, temperature, top_p):
             messages.append({"role": "assistant", "content": val[1]})
     messages.append({"role": "user", "content": message})
 
-    response = ""
+    # Calculate input token count
+    input_tokens = count_tokens(messages)
+    max_allowed_tokens = MODEL_TOKEN_LIMIT - input_tokens
 
+    # Ensure max_tokens does not exceed the model's token limit
+    if max_tokens > max_allowed_tokens:
+        max_tokens = max_allowed_tokens
+
+    validate_inputs(max_tokens, temperature, top_p, input_tokens)
+
+    response = ""
     # Generate response with streaming
     for message in client.chat_completion(
         messages,
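The clamp above works backwards from the window: with MODEL_TOKEN_LIMIT = 16384, a conversation estimated at 15,000 input tokens caps any requested max_tokens at 1,384. One edge it leaves open: once input_tokens reaches or passes the limit, max_allowed_tokens is zero or negative, the clamp drives max_tokens below 1, and validate_inputs still passes because the clamped sum equals the limit exactly. A possible extra guard, sketched here (the wording and placement are hypothetical, not part of this commit):

    # Hypothetical guard, run right after max_allowed_tokens is computed:
    # refuse the request instead of sending max_tokens <= 0 to the API.
    if max_allowed_tokens <= 0:
        raise ValueError(
            f"The conversation already uses ~{input_tokens} of "
            f"{MODEL_TOKEN_LIMIT} tokens. Clear some history and retry."
        )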
@@ -60,7 +74,7 @@ demo = gr.ChatInterface(
     respond,
     additional_inputs=[
         gr.Textbox(value=system_message, label="System message", lines=10),
-        gr.Slider(minimum=1, maximum=32768, value=17012, step=1, label="Max new tokens"),
+        gr.Slider(minimum=1, maximum=16384, value=1000, step=1, label="Max new tokens"),  # Default fixed
         gr.Slider(minimum=0.1, maximum=4.0, value=0.7, step=0.1, label="Temperature"),
         gr.Slider(minimum=0.1, maximum=1.0, value=0.95, step=0.05, label="Top-p (nucleus sampling)"),
     ],
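The earlier hunk cuts off at the head of the chat_completion call, so the streaming body is not shown. For orientation, the usual shape of that loop with huggingface_hub's InferenceClient is sketched below; the stream=True flag and the exact keyword arguments are assumptions, since the diff truncates before them:

    response = ""
    # With stream=True, chat_completion yields incremental chunks
    # instead of a single finished message.
    for chunk in client.chat_completion(
        messages,
        max_tokens=max_tokens,
        temperature=temperature,
        top_p=top_p,
        stream=True,
    ):
        token = chunk.choices[0].delta.content or ""  # content can be None on the final chunk
        response += token
        yield response  # Gradio's ChatInterface re-renders each partial response

The yield makes respond a generator, which is how gr.ChatInterface streams partial replies to the browser.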
 