import os
import gradio as gr
from typing import Callable, Generator
import base64
from openai import OpenAI
def get_fn(model_name: str, **model_kwargs) -> Callable:
"""Create a chat function with the specified model."""
# Instantiate an OpenAI client for a custom endpoint
try:
client = OpenAI(
base_url="http://192.222.58.60:8000/v1",
api_key="tela", # Replace with your actual API key or use environment variables
)
except Exception as e:
print(f"The API or base URL were not defined: {str(e)}")
raise e # Prevent the app from running without a client
def predict(
messages: list, # Preprocessed messages from preprocess function
temperature: float,
max_tokens: int,
top_p: float
) -> Generator[str, None, None]:
try:
# Call the OpenAI API with the formatted messages
response = client.chat.completions.create(
model=model_name,
messages=messages,
temperature=temperature,
max_tokens=max_tokens,
top_p=top_p,
stream=True,
                # response_format expects an object such as {"type": "text"}, not a bare string
response_format={"type": "text"},
)
response_text = ""
# Iterate over the streaming response
for chunk in response:
                content = chunk.choices[0].delta.content
                if content:  # delta.content can be None, e.g. on the final chunk
                    response_text += content
                    yield response_text.strip()
if not response_text.strip():
yield "I apologize, but I was unable to generate a response. Please try again."
except Exception as e:
print(f"Error during generation: {str(e)}")
yield f"An error occurred: {str(e)}"
return predict
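# Example usage of get_fn (illustrative sketch only; "my-model" is a placeholder,
# not a model that necessarily exists on the endpoint above):
#   predict = get_fn("my-model")
#   for partial in predict(messages=[{"role": "user", "content": "Hello"}],
#                          temperature=0.7, max_tokens=256, top_p=0.95):
#       print(partial)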
def get_image_base64(url: str, ext: str) -> str:
with open(url, "rb") as image_file:
encoded_string = base64.b64encode(image_file.read()).decode('utf-8')
return f"data:image/{ext};base64,{encoded_string}"
def handle_user_msg(message: str) -> str:
if isinstance(message, str):
return message
elif isinstance(message, dict):
if message.get("files"):
ext = os.path.splitext(message["files"][-1])[1].strip(".").lower()
if ext in ["png", "jpg", "jpeg", "gif", "pdf"]:
encoded_str = get_image_base64(message["files"][-1], ext)
return f"{message.get('text', '')}\n"
else:
raise NotImplementedError(f"Unsupported file type: {ext}")
else:
return message.get("text", "")
else:
raise NotImplementedError("Unsupported message type")
def get_interface_args(pipeline: str):
if pipeline == "chat":
inputs = None
outputs = None
def preprocess(message, history):
messages = []
files = None
for user_msg, assistant_msg in history:
if assistant_msg is not None:
messages.append({"role": "user", "content": handle_user_msg(user_msg)})
messages.append({"role": "assistant", "content": assistant_msg})
else:
files = user_msg
if isinstance(message, str) and files is not None:
message = {"text": message, "files": files}
elif isinstance(message, dict) and files is not None:
if not message.get("files"):
message["files"] = files
messages.append({"role": "user", "content": handle_user_msg(message)})
return {"messages": messages}
postprocess = lambda x: x # No additional postprocessing needed
else:
# Add other pipeline types when they are needed
raise ValueError(f"Unsupported pipeline type: {pipeline}")
return inputs, outputs, preprocess, postprocess
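# For reference, preprocess converts Gradio's (message, history) pair into the
# OpenAI chat format, e.g.:
#   preprocess("How are you?", [["Hi", "Hello!"]])
#   -> {"messages": [{"role": "user", "content": "Hi"},
#                    {"role": "assistant", "content": "Hello!"},
#                    {"role": "user", "content": "How are you?"}]}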
def registry(name: str = None, **kwargs) -> gr.ChatInterface:
"""Create a Gradio Interface with similar styling and parameters."""
# Retrieve preprocess and postprocess functions
_, _, preprocess, postprocess = get_interface_args("chat")
# Get the predict function
predict_fn = get_fn(model_name=name, **kwargs)
# Define a wrapper function that integrates preprocessing and postprocessing
def wrapper(message, history, system_prompt, temperature, max_tokens, top_p):
# Preprocess the inputs
preprocessed = preprocess(message, history)
        # Extract the preprocessed messages and prepend the system prompt so the model sees it
        messages = preprocessed["messages"]
        if system_prompt:
            messages = [{"role": "system", "content": system_prompt}] + messages
# Call the predict function and generate the response
response_generator = predict_fn(
messages=messages,
temperature=temperature,
max_tokens=max_tokens,
top_p=top_p
)
# Collect the generated response
response = ""
for partial_response in response_generator:
response = partial_response # Gradio will handle streaming
yield response
# Create the Gradio ChatInterface with the wrapper function
interface = gr.ChatInterface(
fn=wrapper,
additional_inputs_accordion=gr.Accordion("⚙️ Parameters", open=False),
additional_inputs=[
gr.Textbox(
value="You are a helpful AI assistant.",
label="System prompt"
),
gr.Slider(0.0, 1.0, value=0.7, label="Temperature"),
gr.Slider(128, 4096, value=1024, label="Max new tokens"),
gr.Slider(0.0, 1.0, value=0.95, label="Top P sampling"),
],
)
return interface
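# Minimal local entry point (a sketch; the Space may launch the interface elsewhere).
# "my-model" is a placeholder model name, not one shipped with this file.
if __name__ == "__main__":
    demo = registry(name="my-model")
    demo.launch()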