import base64
import os
from typing import Callable, Generator, Optional

import gradio as gr
from openai import OpenAI


def get_fn(model_name: str, **model_kwargs) -> Callable:
    """Create a chat function for the specified model."""
    # Instantiate an OpenAI client pointed at a custom, OpenAI-compatible endpoint.
    try:
        client = OpenAI(
            base_url="http://192.222.58.60:8000/v1",
            api_key="tela",
        )
    except Exception as e:
        print(f"The API key or base URL was not defined: {e}")
        # Re-raise so the app does not start without a working client.
        raise

    def predict(
        messages: list,
        system_prompt: str,
        temperature: float,
        max_tokens: int,
        top_k: int,
        repetition_penalty: float,
        top_p: float,
    ) -> Generator[str, None, None]:
        try:
            # Prepend the system prompt to the already-preprocessed conversation.
            full_messages = [{"role": "system", "content": system_prompt}] + messages

            # Call the chat completions endpoint with streaming enabled.
            response = client.chat.completions.create(
                model=model_name,
                messages=full_messages,
                temperature=temperature,
                max_tokens=max_tokens,
                top_p=top_p,
                stream=True,
                # response_format expects a mapping, not a bare string.
                response_format={"type": "text"},
                # top_k and repetition_penalty are not official OpenAI parameters;
                # forward them via extra_body for OpenAI-compatible servers
                # (e.g. vLLM) that accept them.
                extra_body={
                    "top_k": top_k,
                    "repetition_penalty": repetition_penalty,
                },
            )

            response_text = ""
            # Iterate over the streaming response. Chunks are pydantic objects,
            # so fields are accessed as attributes rather than dict keys.
            for chunk in response:
                if chunk.choices and chunk.choices[0].delta:
                    content = chunk.choices[0].delta.content or ""
                    if content:
                        response_text += content
                        yield response_text.strip()

            if not response_text.strip():
                yield "I apologize, but I was unable to generate a response. Please try again."

        except Exception as e:
            print(f"Error during generation: {e}")
            yield f"An error occurred: {e}"

    return predict


def get_image_base64(url: str, ext: str) -> str:
    """Read a local file and return it as a base64 data URL."""
    with open(url, "rb") as image_file:
        encoded_string = base64.b64encode(image_file.read()).decode("utf-8")
    return f"data:image/{ext};base64,{encoded_string}"


def handle_user_msg(message):
    """Convert a Gradio (possibly multimodal) message into OpenAI message content."""
    if isinstance(message, str):
        return message
    if isinstance(message, dict):
        if message.get("files"):
            ext = os.path.splitext(message["files"][-1])[1].strip(".")
            if ext.lower() in ["png", "jpg", "jpeg", "gif", "pdf"]:
                encoded_str = get_image_base64(message["files"][-1], ext)
            else:
                raise NotImplementedError(f"Unsupported file type: {ext}")
            content = [
                {"type": "text", "text": message["text"]},
                {
                    "type": "image_url",
                    "image_url": {
                        "url": encoded_str,
                    },
                },
            ]
        else:
            content = message["text"]
        return content
    raise NotImplementedError(f"Unsupported message type: {type(message)}")


def get_interface_args(pipeline: str):
    """Return the (inputs, outputs, preprocess, postprocess) tuple for a pipeline."""
    if pipeline == "chat":
        inputs = None
        outputs = None

        def preprocess(message, history):
            messages = []
            files = None
            for user_msg, assistant_msg in history:
                if assistant_msg is not None:
                    messages.append({"role": "user", "content": handle_user_msg(user_msg)})
                    messages.append({"role": "assistant", "content": assistant_msg})
                else:
                    # A history turn without an assistant reply carries uploaded files.
                    files = user_msg
            if isinstance(message, str) and files is not None:
                message = {"text": message, "files": files}
            elif isinstance(message, dict) and files is not None:
                if not message.get("files"):
                    message["files"] = files
            messages.append({"role": "user", "content": handle_user_msg(message)})
            return {"messages": messages}

        def postprocess(x):
            # No additional postprocessing is needed for plain-text replies.
            return x
    else:
        # Add other pipeline types when they are needed.
        raise ValueError(f"Unsupported pipeline type: {pipeline}")

    return inputs, outputs, preprocess, postprocess


def registry(name: Optional[str] = None, **kwargs) -> gr.ChatInterface:
    """Create a Gradio ChatInterface wired to the custom OpenAI-compatible endpoint."""
    # Retrieve the preprocess and postprocess functions.
    _, _, preprocess, postprocess = get_interface_args("chat")

    # Build the prediction function for the requested model.
    predict_fn = get_fn(model_name=name, **kwargs)

    # Wrapper that integrates preprocessing, prediction, and postprocessing.
    def wrapper(message, history, system_prompt, temperature, max_tokens,
                top_k, repetition_penalty, top_p):
        # Convert the Gradio message and history into OpenAI-style messages.
        preprocessed = preprocess(message, history)
        messages = preprocessed["messages"]

        # Stream the model's response.
        response_generator = predict_fn(
            messages=messages,
            system_prompt=system_prompt,
            temperature=temperature,
            max_tokens=max_tokens,
            top_k=top_k,
            repetition_penalty=repetition_penalty,
            top_p=top_p,
        )

        # Each partial response already contains the full text generated so far;
        # Gradio renders successive yields as a streaming update.
        for partial_response in response_generator:
            yield postprocess(partial_response)

    # Create the Gradio ChatInterface with the wrapper function.
    interface = gr.ChatInterface(
        fn=wrapper,
        additional_inputs_accordion=gr.Accordion("⚙️ Parameters", open=False),
        additional_inputs=[
            gr.Textbox(
                value="You are a helpful AI assistant.",
                label="System prompt",
            ),
            gr.Slider(0.0, 1.0, value=0.7, label="Temperature"),
            gr.Slider(128, 4096, value=1024, label="Max new tokens"),
            gr.Slider(1, 80, value=40, step=1, label="Top K sampling"),
            gr.Slider(0.0, 2.0, value=1.1, label="Repetition penalty"),
            gr.Slider(0.0, 1.0, value=0.95, label="Top P sampling"),
        ],
        # Other ChatInterface parameters can be customized here if needed.
    )

    return interface
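

# Example usage (a minimal sketch): "placeholder-model" is a stand-in for
# whichever model name the OpenAI-compatible endpoint configured above serves.
if __name__ == "__main__":
    demo = registry(name="placeholder-model")
    demo.launch()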