rodrigomasini committed
Commit 543fed2 · verified · 1 Parent(s): 7e72b19

Update helper.py

Files changed (1):
  1. helper.py +88 -56
helper.py CHANGED
Old version (lines removed in this commit are marked "-"):

@@ -1,74 +1,82 @@
  import os
  import gradio as gr
- from typing import Callable
  import base64
  from openai import OpenAI

-
-
- def get_fn(model_path: str, **model_kwargs):
      """Create a chat function with the specified model."""

-     # instatiate a OpenAI client for a custom endpoint
      try:
-         OPENAI_API_KEY = "-"
          client = OpenAI(
-             base_url=" http://192.222.58.60:8000/v1",
-             api_key="tela",
          )
-
      except Exception as e:
-         print(f"The api or base url were not definied: {str(e)}")
-

      def predict(
          message: str,
-         history,
          system_prompt: str,
          temperature: float,
          max_tokens: int,
          top_k: int,
          repetition_penalty: float,
          top_p: float
-     ):
          try:
-             # Format conversation with ChatML format
-             instruction = '<|im_start|>system\n' + system_prompt + '\n<|im_end|>\n'
              for user_msg, assistant_msg in history:
-                 instruction += f'<|im_start|>user\n{user_msg}\n<|im_end|>\n<|im_start|>assistant\n{assistant_msg}\n<|im_end|>\n'
-             instruction += f'<|im_start|>user\n{message}\n<|im_end|>\n<|im_start|>assistant\n'
-
              response = client.chat.completions.create(
-                 model=model_name,
                  messages=messages,
                  temperature=temperature,
                  max_tokens=max_tokens,
                  top_k=top_k,
                  repetition_penalty=repetition_penalty,
-                 n=1,
                  stream=True,
-                 response_format={"type": "text"},
              )
-
              response_text = ""
              for chunk in response:
-                 streamer = chunk.choices[0].delta.content
-                 for new_token in streamer:
-                     if new_token in ["<|endoftext|>", "<|im_end|>"]:
-                         break
-                     response_text += new_token
-                     yield response_text.strip()
-
              if not response_text.strip():
                  yield "I apologize, but I was unable to generate a response. Please try again."
-
          except Exception as e:
              print(f"Error during generation: {str(e)}")
              yield f"An error occurred: {str(e)}"
-
      return predict


  def get_image_base64(url: str, ext: str):
      with open(url, "rb") as image_file:
          encoded_string = base64.b64encode(image_file.read()).decode('utf-8')
@@ -101,7 +109,7 @@ def handle_user_msg(message: str):
      raise NotImplementedError


- def get_interface_args(pipeline):
      if pipeline == "chat":
          inputs = None
          outputs = None
@@ -115,47 +123,71 @@ def get_interface_args(pipeline):
                      messages.append({"role": "assistant", "content": assistant_msg})
                  else:
                      files = user_msg
-             if type(message) is str and files is not None:
-                 message = {"text":message, "files":files}
-             elif type(message) is dict and files is not None:
-                 if message["files"] is None or len(message["files"]) == 0:
                      message["files"] = files
              messages.append({"role": "user", "content": handle_user_msg(message)})
              return {"messages": messages}

-         postprocess = lambda x: x
      else:
-         # Add other pipeline types when they will be needed
          raise ValueError(f"Unsupported pipeline type: {pipeline}")
      return inputs, outputs, preprocess, postprocess


- def get_pipeline(model_name):
-     # Determine the pipeline type based on the model name
-     # For simplicity, assuming all models are chat models at the moment
-     return "chat"
-
-
-
- def registry(name: str = None, **kwargs):
      """Create a Gradio Interface with similar styling and parameters."""

-     fn = get_fn(name, **kwargs)
-
      interface = gr.ChatInterface(
-         fn=fn,
          additional_inputs_accordion=gr.Accordion("⚙️ Parameters", open=False),
          additional_inputs=[
              gr.Textbox(
-                 "You are a helpful AI assistant.",
                  label="System prompt"
              ),
-             gr.Slider(0, 1, 0.7, label="Temperature"),
-             gr.Slider(128, 4096, 1024, label="Max new tokens"),
-             gr.Slider(1, 80, 40, label="Top K sampling"),
-             gr.Slider(0, 2, 1.1, label="Repetition penalty"),
-             gr.Slider(0, 1, 0.95, label="Top P sampling"),
          ],
      )

-     return interface
 
New version (lines added in this commit are marked "+"):

@@ -1,74 +1,82 @@
  import os
  import gradio as gr
+ from typing import Callable, Generator
  import base64
  from openai import OpenAI

+ def get_fn(model_name: str, **model_kwargs) -> Callable:
      """Create a chat function with the specified model."""

+     # Instantiate an OpenAI client for a custom endpoint
      try:
          client = OpenAI(
+             base_url="http://192.222.58.60:8000/v1",
+             api_key="tela",
          )
      except Exception as e:
+         print(f"The API or base URL were not defined: {str(e)}")
+         raise e  # It's better to raise the exception to prevent the app from running without a client

      def predict(
          message: str,
+         history: list,
          system_prompt: str,
          temperature: float,
          max_tokens: int,
          top_k: int,
          repetition_penalty: float,
          top_p: float
+     ) -> Generator[str, None, None]:
          try:
+             # Initialize the messages list with the system prompt
+             messages = [
+                 {"role": "system", "content": system_prompt}
+             ]
+
+             # Append the conversation history
              for user_msg, assistant_msg in history:
+                 messages.append({"role": "user", "content": user_msg})
+                 if assistant_msg:
+                     messages.append({"role": "assistant", "content": assistant_msg})
+
+             # Append the latest user message
+             messages.append({"role": "user", "content": message})
+
+             # Call the OpenAI API with the formatted messages
              response = client.chat.completions.create(
+                 model=model_name,
                  messages=messages,
                  temperature=temperature,
                  max_tokens=max_tokens,
                  top_k=top_k,
                  repetition_penalty=repetition_penalty,
+                 top_p=top_p,
                  stream=True,
+                 # Ensure response_format is set correctly; typically it's a string like 'text'
+                 response_format="text",
              )
+
              response_text = ""
+             # Iterate over the streaming response
              for chunk in response:
+                 if 'choices' in chunk and len(chunk['choices']) > 0:
+                     delta = chunk['choices'][0].get('delta', {})
+                     content = delta.get('content', '')
+                     if content:
+                         response_text += content
+                         yield response_text.strip()
+
              if not response_text.strip():
                  yield "I apologize, but I was unable to generate a response. Please try again."
+
          except Exception as e:
              print(f"Error during generation: {str(e)}")
              yield f"An error occurred: {str(e)}"
+
      return predict


+
  def get_image_base64(url: str, ext: str):
      with open(url, "rb") as image_file:
          encoded_string = base64.b64encode(image_file.read()).decode('utf-8')
 
@@ -101,7 +109,7 @@ def handle_user_msg(message: str):
      raise NotImplementedError


+ def get_interface_args(pipeline: str):
      if pipeline == "chat":
          inputs = None
          outputs = None
123
  messages.append({"role": "assistant", "content": assistant_msg})
124
  else:
125
  files = user_msg
126
+ if isinstance(message, str) and files is not None:
127
+ message = {"text": message, "files": files}
128
+ elif isinstance(message, dict) and files is not None:
129
+ if not message.get("files"):
130
  message["files"] = files
131
  messages.append({"role": "user", "content": handle_user_msg(message)})
132
  return {"messages": messages}
133
 
134
+ postprocess = lambda x: x # No additional postprocessing needed
135
+
136
  else:
137
+ # Add other pipeline types when they are needed
138
  raise ValueError(f"Unsupported pipeline type: {pipeline}")
139
  return inputs, outputs, preprocess, postprocess
140
 
141
 
142
+ def registry(name: str = None, **kwargs) -> gr.ChatInterface:
 
 
 
 
 
 
 
143
  """Create a Gradio Interface with similar styling and parameters."""
144
 
145
+ # Retrieve preprocess and postprocess functions
146
+ _, _, preprocess, postprocess = get_interface_args("chat")
147
+
148
+ # Get the predict function
149
+ predict_fn = get_fn(model_path=name, **kwargs)
150
+
151
+ # Define a wrapper function that integrates preprocessing and postprocessing
152
+ def wrapper(message, history, system_prompt, temperature, max_tokens, top_k, repetition_penalty, top_p):
153
+ # Preprocess the inputs
154
+ preprocessed = preprocess(message, history)
155
+
156
+ # Extract the preprocessed messages
157
+ messages = preprocessed["messages"]
158
+
159
+ # Call the predict function and generate the response
160
+ response_generator = predict_fn(
161
+ messages=messages,
162
+ temperature=temperature,
163
+ max_tokens=max_tokens,
164
+ top_k=top_k,
165
+ repetition_penalty=repetition_penalty,
166
+ top_p=top_p
167
+ )
168
+
169
+ # Collect the generated response
170
+ response = ""
171
+ for partial_response in response_generator:
172
+ response = partial_response # Gradio will handle streaming
173
+ yield response
174
+
175
+ # Create the Gradio ChatInterface with the wrapper function
176
  interface = gr.ChatInterface(
177
+ fn=wrapper,
178
  additional_inputs_accordion=gr.Accordion("⚙️ Parameters", open=False),
179
  additional_inputs=[
180
  gr.Textbox(
181
+ value="You are a helpful AI assistant.",
182
  label="System prompt"
183
  ),
184
+ gr.Slider(0.0, 1.0, value=0.7, label="Temperature"),
185
+ gr.Slider(128, 4096, value=1024, label="Max new tokens"),
186
+ gr.Slider(1, 80, value=40, step=1, label="Top K sampling"),
187
+ gr.Slider(0.0, 2.0, value=1.1, label="Repetition penalty"),
188
+ gr.Slider(0.0, 1.0, value=0.95, label="Top P sampling"),
189
  ],
190
+ # Optionally, you can customize other ChatInterface parameters here
191
  )
192
 
193
+ return interface
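
For reference, a minimal sketch of how a Space entry point might consume the updated helper. The module import, the file name app.py, and the model name below are illustrative assumptions and are not part of this commit:

# app.py (hypothetical usage sketch; "my-chat-model" is a placeholder name)
import helper

# registry() builds the gr.ChatInterface around the streaming predict wrapper defined above
demo = helper.registry(name="my-chat-model")

if __name__ == "__main__":
    # queue() lets the partial responses yielded by the wrapper stream to the browser
    demo.queue().launch()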