# marco-o1 / helper.py
import os
import gradio as gr
from typing import Callable, Generator
import base64
from openai import OpenAI
def get_fn(model_name: str, **model_kwargs) -> Callable:
    """Create a chat function with the specified model."""
    # Instantiate an OpenAI client for a custom OpenAI-compatible endpoint
    try:
        client = OpenAI(
            base_url="http://192.222.58.60:8000/v1",
            api_key="tela",
        )
    except Exception as e:
        print(f"Failed to initialize the OpenAI client: {str(e)}")
        raise  # Re-raise so the app does not run without a working client
    def predict(
        messages: list,
        system_prompt: str,
        temperature: float,
        max_tokens: int,
        top_k: int,
        repetition_penalty: float,
        top_p: float,
    ) -> Generator[str, None, None]:
        try:
            # Prepend the system prompt to the preprocessed conversation
            # (the wrapper in registry() builds the user/assistant messages)
            messages = [{"role": "system", "content": system_prompt}] + messages
            # Call the OpenAI-compatible API with the formatted messages
            response = client.chat.completions.create(
                model=model_name,
                messages=messages,
                temperature=temperature,
                max_tokens=max_tokens,
                top_p=top_p,
                stream=True,
                # response_format must be a dict such as {"type": "text"}, not a bare string
                response_format={"type": "text"},
                # top_k and repetition_penalty are not standard OpenAI parameters;
                # forward them through extra_body so an OpenAI-compatible backend can use them
                extra_body={
                    "top_k": top_k,
                    "repetition_penalty": repetition_penalty,
                },
            )
            response_text = ""
            # Iterate over the streaming response; chunks are objects, not dicts
            for chunk in response:
                if chunk.choices and chunk.choices[0].delta.content:
                    response_text += chunk.choices[0].delta.content
                    yield response_text.strip()
            if not response_text.strip():
                yield "I apologize, but I was unable to generate a response. Please try again."
except Exception as e:
print(f"Error during generation: {str(e)}")
yield f"An error occurred: {str(e)}"
return predict
def get_image_base64(url: str, ext: str):
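    """Read a local file and return it as a base64-encoded data URL with the given extension."""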
with open(url, "rb") as image_file:
encoded_string = base64.b64encode(image_file.read()).decode('utf-8')
return "data:image/" + ext + ";base64," + encoded_string
def handle_user_msg(message: str):
    """Normalize a Gradio user message (plain string or multimodal dict) into API content."""
    if isinstance(message, str):
        return message
    elif isinstance(message, dict):
if message["files"] is not None and len(message["files"]) > 0:
ext = os.path.splitext(message["files"][-1])[1].strip(".")
if ext.lower() in ["png", "jpg", "jpeg", "gif", "pdf"]:
encoded_str = get_image_base64(message["files"][-1], ext)
else:
raise NotImplementedError(f"Not supported file type {ext}")
content = [
{"type": "text", "text": message["text"]},
{
"type": "image_url",
"image_url": {
"url": encoded_str,
}
},
]
else:
content = message["text"]
return content
else:
        raise NotImplementedError(f"Unsupported message type: {type(message)}")
def get_interface_args(pipeline: str):
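    """Return (inputs, outputs, preprocess, postprocess) helpers for the given pipeline type."""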
if pipeline == "chat":
inputs = None
outputs = None
def preprocess(message, history):
messages = []
files = None
for user_msg, assistant_msg in history:
if assistant_msg is not None:
messages.append({"role": "user", "content": handle_user_msg(user_msg)})
messages.append({"role": "assistant", "content": assistant_msg})
else:
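                    # A history turn with no assistant reply is treated as a file upload
                    # that belongs to the next user message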
files = user_msg
if isinstance(message, str) and files is not None:
message = {"text": message, "files": files}
elif isinstance(message, dict) and files is not None:
if not message.get("files"):
message["files"] = files
messages.append({"role": "user", "content": handle_user_msg(message)})
return {"messages": messages}
postprocess = lambda x: x # No additional postprocessing needed
else:
# Add other pipeline types when they are needed
raise ValueError(f"Unsupported pipeline type: {pipeline}")
return inputs, outputs, preprocess, postprocess
def registry(name: str = None, **kwargs) -> gr.ChatInterface:
    """Create a Gradio ChatInterface wired to the custom OpenAI-compatible endpoint."""
# Retrieve preprocess and postprocess functions
_, _, preprocess, postprocess = get_interface_args("chat")
# Get the predict function
predict_fn = get_fn(model_name=name, **kwargs)
# Define a wrapper function that integrates preprocessing and postprocessing
def wrapper(message, history, system_prompt, temperature, max_tokens, top_k, repetition_penalty, top_p):
# Preprocess the inputs
preprocessed = preprocess(message, history)
# Extract the preprocessed messages
messages = preprocessed["messages"]
        # Call the predict function with the formatted messages and sampling settings
        response_generator = predict_fn(
            messages=messages,
            system_prompt=system_prompt,
            temperature=temperature,
            max_tokens=max_tokens,
            top_k=top_k,
            repetition_penalty=repetition_penalty,
            top_p=top_p,
        )
        # Stream each partial response back to the ChatInterface
response = ""
for partial_response in response_generator:
response = partial_response # Gradio will handle streaming
yield response
# Create the Gradio ChatInterface with the wrapper function
interface = gr.ChatInterface(
fn=wrapper,
additional_inputs_accordion=gr.Accordion("⚙️ Parameters", open=False),
additional_inputs=[
gr.Textbox(
value="You are a helpful AI assistant.",
label="System prompt"
),
gr.Slider(0.0, 1.0, value=0.7, label="Temperature"),
gr.Slider(128, 4096, value=1024, label="Max new tokens"),
gr.Slider(1, 80, value=40, step=1, label="Top K sampling"),
gr.Slider(0.0, 2.0, value=1.1, label="Repetition penalty"),
gr.Slider(0.0, 1.0, value=0.95, label="Top P sampling"),
],
# Optionally, you can customize other ChatInterface parameters here
)
return interface
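# Minimal usage sketch: assuming this helper is imported by an app.py entry point in the
# same Space, the interface can be built and served as below. The model name "marco-o1"
# is an assumption taken from the repository name, not a value defined in this file.
if __name__ == "__main__":
    demo = registry(name="marco-o1")
    demo.launch()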