jtmuller committed
Commit a769908 · 1 Parent(s): 914dfee

Update Space

Files changed (1):
  1. app.py +139 -30

app.py CHANGED
@@ -1,64 +1,173 @@
 import gradio as gr
 from huggingface_hub import InferenceClient
 
-"""
-For more information on `huggingface_hub` Inference API support, please check the docs: https://huggingface.co/docs/huggingface_hub/v0.22.2/en/guides/inference
-"""
 client = InferenceClient("HuggingFaceH4/zephyr-7b-beta")
 
 
-def respond(
-    message,
-    history: list[tuple[str, str]],
-    system_message,
-    max_tokens,
-    temperature,
-    top_p,
-):
-    messages = [{"role": "system", "content": system_message}]
 
-    for val in history:
-        if val[0]:
-            messages.append({"role": "user", "content": val[0]})
-        if val[1]:
-            messages.append({"role": "assistant", "content": val[1]})
 
     messages.append({"role": "user", "content": message})
 
-    response = ""
 
-    for message in client.chat_completion(
         messages,
         max_tokens=max_tokens,
         stream=True,
         temperature=temperature,
         top_p=top_p,
     ):
-        token = message.choices[0].delta.content
-
         response += token
-        yield response
 
 
 """
-For information on how to customize the ChatInterface, peruse the gradio docs: https://www.gradio.app/docs/chatinterface
 """
 demo = gr.ChatInterface(
-    respond,
     additional_inputs=[
-        gr.Textbox(value="You are a friendly Chatbot.", label="System message"),
-        gr.Slider(minimum=1, maximum=2048, value=512, step=1, label="Max new tokens"),
-        gr.Slider(minimum=0.1, maximum=4.0, value=0.7, step=0.1, label="Temperature"),
         gr.Slider(
             minimum=0.1,
             maximum=1.0,
             value=0.95,
             step=0.05,
-            label="Top-p (nucleus sampling)",
         ),
     ],
 )
 
-
 if __name__ == "__main__":
-    demo.launch()
 
+import sys
+import onnxruntime as ort
+import numpy as np
+import string
+
+# Transformers, HuggingFace Hub, and Gradio
+from transformers import AutoTokenizer
 import gradio as gr
 from huggingface_hub import InferenceClient
 
+# ------------------------------------------------
+# Turn Detector Configuration
+# ------------------------------------------------
+HG_MODEL = "livekit/turn-detector"  # or your HF model repo
+ONNX_FILENAME = "model_quantized.onnx"  # path to your ONNX file
+MAX_HISTORY_TOKENS = 512
+PUNCS = string.punctuation.replace("'", "")
+
+# ------------------------------------------------
+# Utility functions
+# ------------------------------------------------
+def softmax(logits: np.ndarray) -> np.ndarray:
+    exp_logits = np.exp(logits - np.max(logits))
+    return exp_logits / np.sum(exp_logits)
+
+def normalize_text(text: str) -> str:
+    """Lowercase, strip punctuation (except single quotes), and collapse whitespace."""
+    def strip_puncs(text_in):
+        return text_in.translate(str.maketrans("", "", PUNCS))
+    return " ".join(strip_puncs(text).lower().split())
+
+def calculate_eou(chat_ctx, session, tokenizer) -> float:
+    """
+    Given a conversation context (list of dicts with 'role' and 'content'),
+    returns the probability that the user is finished speaking.
+    """
+    # Collect normalized messages from 'user' or 'assistant' roles
+    normalized_ctx = []
+    for msg in chat_ctx:
+        if msg["role"] in ("user", "assistant"):
+            content = normalize_text(msg["content"])
+            if content:
+                normalized_ctx.append(content)
+
+    # Join them into one input string
+    text = " ".join(normalized_ctx)
+    inputs = tokenizer(
+        text,
+        return_tensors="np",
+        truncation=True,
+        max_length=MAX_HISTORY_TOKENS,
+    )
+
+    input_ids = np.array(inputs["input_ids"], dtype=np.int64)
+    # Run inference
+    outputs = session.run(["logits"], {"input_ids": input_ids})
+    logits = outputs[0][0, -1, :]
+
+    # Softmax over logits
+    probs = softmax(logits)
+    # The ID for the <|im_end|> special token
+    eou_token_id = tokenizer.encode("<|im_end|>")[-1]
+    return probs[eou_token_id]
+
+# ------------------------------------------------
+# Load ONNX session & tokenizer once
+# ------------------------------------------------
+print("Loading ONNX model session...")
+onnx_session = ort.InferenceSession(ONNX_FILENAME, providers=["CPUExecutionProvider"])
+
+print("Loading tokenizer...")
+turn_detector_tokenizer = AutoTokenizer.from_pretrained(HG_MODEL)
+
+# ------------------------------------------------
+# HF InferenceClient for text generation (example)
+# ------------------------------------------------
 client = InferenceClient("HuggingFaceH4/zephyr-7b-beta")
+# Adjust above to any other endpoint that suits your use case.
 
+# ------------------------------------------------
+# Gradio Chat Handler
+# ------------------------------------------------
+def respond(message, history, system_message, max_tokens, temperature, top_p):
+    """
+    This function is called on each new user message in the ChatInterface.
+    - 'message' is the new user input
+    - 'history' is a list of (user, assistant) tuples
+    - 'system_message' is from the system Textbox
+    - max_tokens, temperature, top_p come from the Sliders
+    """
 
+    # 1) Build a list of messages in the OpenAI-style format:
+    #    [{'role': 'system', 'content': ...},
+    #     {'role': 'user', 'content': ...}, ...]
 
+    messages = []
+    if system_message.strip():
+        messages.append({"role": "system", "content": system_message})
 
+    # history is a list of tuples: [(user1, assistant1), (user2, assistant2), ...]
+    for user_text, assistant_text in history:
+        if user_text:
+            messages.append({"role": "user", "content": user_text})
+        if assistant_text:
+            messages.append({"role": "assistant", "content": assistant_text})
+
+    # Append the new user message
     messages.append({"role": "user", "content": message})
 
+    # 2) Calculate EOU probability on the entire conversation
+    eou_prob = calculate_eou(messages, onnx_session, turn_detector_tokenizer)
 
+    # 3) Generate the assistant response from your HF model.
+    #    (This code streams token-by-token.)
+    response = ""
+    for resp_chunk in client.chat_completion(
         messages,
         max_tokens=max_tokens,
         stream=True,
         temperature=temperature,
         top_p=top_p,
     ):
+        token = resp_chunk.choices[0].delta.content or ""
         response += token
 
+        # If you want to display the partial response with the EOU probability
+        # appended at the bottom, you can do so each step. For cleanliness,
+        # we'll do it in-line as a bracketed note at the end.
+        yield response + f"\n\n[EOU Probability: {eou_prob:.4f}]"
 
+# ------------------------------------------------
+# Gradio ChatInterface
+# ------------------------------------------------
 """
+This ChatInterface will have:
+- A chat box
+- A system message textbox
+- 3 sliders for max_tokens, temperature, and top_p
 """
 demo = gr.ChatInterface(
+    fn=respond,
     additional_inputs=[
+        gr.Textbox(
+            value="You are a friendly Chatbot.",
+            label="System message",
+            lines=2
+        ),
+        gr.Slider(
+            minimum=1,
+            maximum=2048,
+            value=512,
+            step=1,
+            label="Max new tokens"
+        ),
+        gr.Slider(
+            minimum=0.1,
+            maximum=4.0,
+            value=0.7,
+            step=0.1,
+            label="Temperature"
+        ),
         gr.Slider(
             minimum=0.1,
             maximum=1.0,
             value=0.95,
             step=0.05,
+            label="Top-p (nucleus sampling)"
         ),
     ],
 )
 
 if __name__ == "__main__":
+    demo.launch()
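
For anyone reviewing this commit who wants to poke at the turn detector outside the Space, below is a minimal standalone sketch of the same EOU recipe `calculate_eou` implements: tokenize, take the last-token logits from the ONNX session, softmax, and read off the probability of the `<|im_end|>` token. This script is not part of the commit; it assumes `model_quantized.onnx` exists at the root of the `livekit/turn-detector` repo (mirroring `ONNX_FILENAME` in app.py), which may not match the repo's actual file layout, and it assumes the same "input_ids"/"logits" tensor names the Space uses.

# eou_smoke_test.py — hypothetical helper, not part of this commit.
import numpy as np
import onnxruntime as ort
from huggingface_hub import hf_hub_download
from transformers import AutoTokenizer

REPO = "livekit/turn-detector"
FILENAME = "model_quantized.onnx"  # assumed path; mirrors ONNX_FILENAME in app.py

# Download (and cache) the ONNX weights, then load the session and tokenizer once.
onnx_path = hf_hub_download(repo_id=REPO, filename=FILENAME)
session = ort.InferenceSession(onnx_path, providers=["CPUExecutionProvider"])
tokenizer = AutoTokenizer.from_pretrained(REPO)

def eou_probability(text: str) -> float:
    """Same steps as calculate_eou() in app.py, for a single utterance."""
    inputs = tokenizer(text, return_tensors="np", truncation=True, max_length=512)
    input_ids = np.array(inputs["input_ids"], dtype=np.int64)
    # Last-token logits from the "logits" output of the ONNX graph.
    logits = session.run(["logits"], {"input_ids": input_ids})[0][0, -1, :]
    exp = np.exp(logits - logits.max())
    probs = exp / exp.sum()
    # Probability mass assigned to the <|im_end|> end-of-utterance token.
    return float(probs[tokenizer.encode("<|im_end|>")[-1]])

# A complete question should score noticeably higher than a trailing clause.
print(eou_probability("what time does the store close today"))
print(eou_probability("so i was thinking that maybe we could"))

If the filename assumption is wrong, `huggingface_hub.list_repo_files(REPO)` will show where the ONNX file actually lives in the repo.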