import base64
import os
from typing import Callable, Generator

import gradio as gr
from openai import OpenAI
def get_fn(model_name: str, **model_kwargs) -> Callable:
    """Create a streaming chat function for the specified model."""
    # Instantiate an OpenAI client pointed at a custom OpenAI-compatible endpoint
    try:
        client = OpenAI(
            base_url="http://192.222.58.60:8000/v1",
            api_key="tela",
        )
    except Exception as e:
        print(f"Failed to create the OpenAI client: {e}")
        raise  # Re-raise so the app does not start without a working client

    def predict(
        messages: list,
        system_prompt: str,
        temperature: float,
        max_tokens: int,
        top_k: int,
        repetition_penalty: float,
        top_p: float,
    ) -> Generator[str, None, None]:
        try:
            # Prepend the system prompt to the preformatted conversation
            full_messages = [{"role": "system", "content": system_prompt}] + messages
            # Call the endpoint with the formatted messages. top_k and
            # repetition_penalty are not part of the official OpenAI API, so
            # they are passed via extra_body, which OpenAI-compatible servers
            # such as vLLM accept.
            response = client.chat.completions.create(
                model=model_name,
                messages=full_messages,
                temperature=temperature,
                max_tokens=max_tokens,
                top_p=top_p,
                stream=True,
                extra_body={
                    "top_k": top_k,
                    "repetition_penalty": repetition_penalty,
                },
            )
            response_text = ""
            # Accumulate streamed deltas and yield the running response
            for chunk in response:
                if chunk.choices and chunk.choices[0].delta.content:
                    response_text += chunk.choices[0].delta.content
                    yield response_text.strip()
            if not response_text.strip():
                yield "I apologize, but I was unable to generate a response. Please try again."
        except Exception as e:
            print(f"Error during generation: {e}")
            yield f"An error occurred: {e}"

    return predict
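
# The returned predict() is a generator, so callers stream text by iterating it.
# A minimal sketch (the model name, message, and sampling values below are
# illustrative, not taken from this repo):
#
#     fn = get_fn(model_name="my-model")
#     for partial in fn([{"role": "user", "content": "Hi"}],
#                       "You are a helpful AI assistant.", 0.7, 1024, 40, 1.1, 0.95):
#         print(partial)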
def get_image_base64(path: str, ext: str) -> str:
    """Read a local image file and return it as a base64 data URI."""
    ext = ext.lower()
    if ext == "jpg":
        ext = "jpeg"  # the registered MIME type is image/jpeg, not image/jpg
    with open(path, "rb") as image_file:
        encoded_string = base64.b64encode(image_file.read()).decode("utf-8")
    return f"data:image/{ext};base64,{encoded_string}"
def handle_user_msg(message):
    """Convert a Gradio message (plain text or text plus files) into OpenAI content."""
    if isinstance(message, str):
        return message
    elif isinstance(message, dict):
        if message.get("files"):
            ext = os.path.splitext(message["files"][-1])[1].strip(".")
            if ext.lower() in ["png", "jpg", "jpeg", "gif"]:
                encoded_str = get_image_base64(message["files"][-1], ext)
            else:
                raise NotImplementedError(f"Unsupported file type: {ext}")
            # Vision-style content: a text part plus an image_url part
            content = [
                {"type": "text", "text": message["text"]},
                {
                    "type": "image_url",
                    "image_url": {
                        "url": encoded_str,
                    },
                },
            ]
        else:
            content = message["text"]
        return content
    else:
        raise NotImplementedError(f"Unsupported message type: {type(message)}")
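
# For a multimodal message such as {"text": "Describe this image", "files": ["cat.png"]}
# (illustrative values), handle_user_msg returns a two-part content list:
#     [{"type": "text", "text": "Describe this image"},
#      {"type": "image_url", "image_url": {"url": "data:image/png;base64,..."}}]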
def get_interface_args(pipeline: str):
    if pipeline == "chat":
        inputs = None
        outputs = None

        def preprocess(message, history):
            messages = []
            files = None
            for user_msg, assistant_msg in history:
                if assistant_msg is not None:
                    messages.append({"role": "user", "content": handle_user_msg(user_msg)})
                    messages.append({"role": "assistant", "content": assistant_msg})
                else:
                    # A user turn without a reply yet carries the uploaded files
                    files = user_msg
            if isinstance(message, str) and files is not None:
                message = {"text": message, "files": files}
            elif isinstance(message, dict) and files is not None:
                if not message.get("files"):
                    message["files"] = files
            messages.append({"role": "user", "content": handle_user_msg(message)})
            return {"messages": messages}

        postprocess = lambda x: x  # no additional postprocessing needed
    else:
        # Add other pipeline types when they are needed
        raise ValueError(f"Unsupported pipeline type: {pipeline}")
    return inputs, outputs, preprocess, postprocess
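
# A sketch of what preprocess produces for a plain-text exchange
# (illustrative strings):
#
#     preprocess("How are you?", [["Hello", "Hi there!"]])
#     # -> {"messages": [{"role": "user", "content": "Hello"},
#     #                  {"role": "assistant", "content": "Hi there!"},
#     #                  {"role": "user", "content": "How are you?"}]}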
def registry(name: str | None = None, **kwargs) -> gr.ChatInterface:
    """Create a Gradio ChatInterface wired to the custom OpenAI-compatible endpoint."""
    # Retrieve the preprocess and postprocess functions
    _, _, preprocess, postprocess = get_interface_args("chat")
    # Get the streaming predict function
    predict_fn = get_fn(model_name=name, **kwargs)

    # Wrapper that ties preprocessing and postprocessing to the predict function
    def wrapper(message, history, system_prompt, temperature, max_tokens, top_k, repetition_penalty, top_p):
        # Preprocess the inputs into an OpenAI-style messages list
        preprocessed = preprocess(message, history)
        messages = preprocessed["messages"]
        # Call the predict function and stream the response
        response_generator = predict_fn(
            messages=messages,
            system_prompt=system_prompt,
            temperature=temperature,
            max_tokens=max_tokens,
            top_k=top_k,
            repetition_penalty=repetition_penalty,
            top_p=top_p,
        )
        # Yield each partial response; Gradio handles the streaming display
        for partial_response in response_generator:
            yield postprocess(partial_response)
    # Create the Gradio ChatInterface around the wrapper function
    interface = gr.ChatInterface(
        fn=wrapper,
        additional_inputs_accordion=gr.Accordion("⚙️ Parameters", open=False),
        additional_inputs=[
            gr.Textbox(
                value="You are a helpful AI assistant.",
                label="System prompt",
            ),
            gr.Slider(0.0, 1.0, value=0.7, label="Temperature"),
            gr.Slider(128, 4096, value=1024, label="Max new tokens"),
            gr.Slider(1, 80, value=40, step=1, label="Top K sampling"),
            gr.Slider(0.0, 2.0, value=1.1, label="Repetition penalty"),
            gr.Slider(0.0, 1.0, value=0.95, label="Top P sampling"),
        ],
        # Other ChatInterface parameters can be customized here if needed
    )
    return interface
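

# Minimal launch sketch: the model name below is a placeholder and must match
# a model actually served by the OpenAI-compatible endpoint configured above.
if __name__ == "__main__":
    demo = registry(name="my-served-model")  # hypothetical model name
    demo.launch()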