Update helper.py
helper.py
CHANGED
@@ -1,74 +1,82 @@
 import os
 import gradio as gr
-from typing import Callable
+from typing import Callable, Generator
 import base64
 from openai import OpenAI

-
-
-def get_fn(model_path: str, **model_kwargs):
+def get_fn(model_name: str, **model_kwargs) -> Callable:
     """Create a chat function with the specified model."""

-    # …
+    # Instantiate an OpenAI client for a custom endpoint
     try:
-        OPENAI_API_KEY = "-"
         client = OpenAI(
-            …
-            …
+            base_url="http://192.222.58.60:8000/v1",
+            api_key="tela",
         )
-        …
     except Exception as e:
-        print(f"The …
-        …
+        print(f"The API or base URL were not defined: {str(e)}")
+        raise e  # It's better to raise the exception to prevent the app from running without a client

     def predict(
         message: str,
-        history,
+        history: list,
         system_prompt: str,
         temperature: float,
         max_tokens: int,
         top_k: int,
         repetition_penalty: float,
         top_p: float
-    ):
+    ) -> Generator[str, None, None]:
         try:
-            # …
-            …
+            # Initialize the messages list with the system prompt
+            messages = [
+                {"role": "system", "content": system_prompt}
+            ]
+
+            # Append the conversation history
             for user_msg, assistant_msg in history:
-                …
-                …
-                …
+                messages.append({"role": "user", "content": user_msg})
+                if assistant_msg:
+                    messages.append({"role": "assistant", "content": assistant_msg})
+
+            # Append the latest user message
+            messages.append({"role": "user", "content": message})
+
+            # Call the OpenAI API with the formatted messages
             response = client.chat.completions.create(
-                model=model_name,
+                model=model_name,
                 messages=messages,
                 temperature=temperature,
                 max_tokens=max_tokens,
                 top_k=top_k,
                 repetition_penalty=repetition_penalty,
-                …
+                top_p=top_p,
                 stream=True,
-                response_format…
+                # Ensure response_format is set correctly; typically it's a string like 'text'
+                response_format="text",
             )
-            …
+
             response_text = ""
+            # Iterate over the streaming response
             for chunk in response:
-                …
-                …
-                …
-                …
-                …
-                …
-                …
+                if 'choices' in chunk and len(chunk['choices']) > 0:
+                    delta = chunk['choices'][0].get('delta', {})
+                    content = delta.get('content', '')
+                    if content:
+                        response_text += content
+                        yield response_text.strip()
+
             if not response_text.strip():
                 yield "I apologize, but I was unable to generate a response. Please try again."
-            …
+
         except Exception as e:
             print(f"Error during generation: {str(e)}")
             yield f"An error occurred: {str(e)}"
-        …
+
     return predict


+
 def get_image_base64(url: str, ext: str):
     with open(url, "rb") as image_file:
         encoded_string = base64.b64encode(image_file.read()).decode('utf-8')
@@ -101,7 +109,7 @@ def handle_user_msg(message: str):
         raise NotImplementedError


-def get_interface_args(pipeline):
+def get_interface_args(pipeline: str):
     if pipeline == "chat":
         inputs = None
         outputs = None
@@ -115,47 +123,71 @@ def get_interface_args(pipeline):
                     messages.append({"role": "assistant", "content": assistant_msg})
                 else:
                     files = user_msg
-            if …
-                message = {"text":message, "files":files}
-            elif …
-                if message …
+            if isinstance(message, str) and files is not None:
+                message = {"text": message, "files": files}
+            elif isinstance(message, dict) and files is not None:
+                if not message.get("files"):
                     message["files"] = files
             messages.append({"role": "user", "content": handle_user_msg(message)})
             return {"messages": messages}

-        postprocess = lambda x: x
+        postprocess = lambda x: x  # No additional postprocessing needed
+
     else:
-        # Add other pipeline types when they …
+        # Add other pipeline types when they are needed
         raise ValueError(f"Unsupported pipeline type: {pipeline}")
     return inputs, outputs, preprocess, postprocess


-def …
-    # Determine the pipeline type based on the model name
-    # For simplicity, assuming all models are chat models at the moment
-    return "chat"
-
-
-
-def registry(name: str = None, **kwargs):
+def registry(name: str = None, **kwargs) -> gr.ChatInterface:
     """Create a Gradio Interface with similar styling and parameters."""

-    …
-    …
+    # Retrieve preprocess and postprocess functions
+    _, _, preprocess, postprocess = get_interface_args("chat")
+
+    # Get the predict function
+    predict_fn = get_fn(model_path=name, **kwargs)
+
+    # Define a wrapper function that integrates preprocessing and postprocessing
+    def wrapper(message, history, system_prompt, temperature, max_tokens, top_k, repetition_penalty, top_p):
+        # Preprocess the inputs
+        preprocessed = preprocess(message, history)
+
+        # Extract the preprocessed messages
+        messages = preprocessed["messages"]
+
+        # Call the predict function and generate the response
+        response_generator = predict_fn(
+            messages=messages,
+            temperature=temperature,
+            max_tokens=max_tokens,
+            top_k=top_k,
+            repetition_penalty=repetition_penalty,
+            top_p=top_p
+        )
+
+        # Collect the generated response
+        response = ""
+        for partial_response in response_generator:
+            response = partial_response  # Gradio will handle streaming
+            yield response
+
+    # Create the Gradio ChatInterface with the wrapper function
     interface = gr.ChatInterface(
-        fn=…
+        fn=wrapper,
         additional_inputs_accordion=gr.Accordion("⚙️ Parameters", open=False),
         additional_inputs=[
             gr.Textbox(
-                "You are a helpful AI assistant.",
+                value="You are a helpful AI assistant.",
                 label="System prompt"
             ),
-            gr.Slider(0, 1, 0.7, label="Temperature"),
-            gr.Slider(128, 4096, 1024, label="Max new tokens"),
-            gr.Slider(1, 80, 40, label="Top K sampling"),
-            gr.Slider(0, 2, 1.1, label="Repetition penalty"),
-            gr.Slider(0, 1, 0.95, label="Top P sampling"),
+            gr.Slider(0.0, 1.0, value=0.7, label="Temperature"),
+            gr.Slider(128, 4096, value=1024, label="Max new tokens"),
+            gr.Slider(1, 80, value=40, step=1, label="Top K sampling"),
+            gr.Slider(0.0, 2.0, value=1.1, label="Repetition penalty"),
+            gr.Slider(0.0, 1.0, value=0.95, label="Top P sampling"),
         ],
+        # Optionally, you can customize other ChatInterface parameters here
    )

-    return interface
+    return interface
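
Note on the streaming and sampling parameters used above: with the 1.x openai Python client (which `from openai import OpenAI` implies), streamed chunks are objects rather than dictionaries, so they are normally read with attribute access (`chunk.choices[0].delta.content`), and non-standard options such as `top_k` and `repetition_penalty` are usually forwarded to an OpenAI-compatible server (for example vLLM) through `extra_body` rather than as top-level keyword arguments. A minimal sketch under those assumptions, with a placeholder model name, not the committed implementation:

    from openai import OpenAI

    # Same custom endpoint as in the diff; "my-model" is a placeholder model name.
    client = OpenAI(base_url="http://192.222.58.60:8000/v1", api_key="tela")

    def stream_reply(messages: list) -> str:
        """Stream a chat completion and return the accumulated text."""
        stream = client.chat.completions.create(
            model="my-model",
            messages=messages,
            temperature=0.7,
            max_tokens=1024,
            top_p=0.95,
            stream=True,
            # Extra sampling options go through extra_body so the client
            # accepts them; the server must support these fields.
            extra_body={"top_k": 40, "repetition_penalty": 1.1},
        )
        text = ""
        for chunk in stream:
            # Streamed chunks expose their payload as attributes, not dict keys.
            if chunk.choices and chunk.choices[0].delta.content:
                text += chunk.choices[0].delta.content
        return text

Called with an ordinary messages list (for example [{"role": "user", "content": "Hello"}]), this returns the full reply once the stream completes.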
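
For context, a Space typically imports this helper from its app.py and launches the returned interface. A hypothetical entry point (the file layout and model name are illustrative, not part of this commit) might look like the sketch below; note that, as committed, registry calls get_fn(model_path=name, ...) while get_fn now declares model_name, so those two names would need to agree for the call to succeed.

    # app.py — hypothetical usage of the helper shown above
    from helper import registry

    # "my-model" is a placeholder served by the configured endpoint.
    demo = registry(name="my-model")

    if __name__ == "__main__":
        demo.launch()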