"""Gradio ChatInterface backed by an OpenAI-compatible chat-completions endpoint."""

import base64
import os
from typing import Callable

import gradio as gr
from openai import OpenAI


def get_fn(model_name: str, **model_kwargs) -> Callable:
    """Create a streaming chat function that talks to the OpenAI-compatible endpoint.

    Args:
        model_name: Model identifier passed to ``chat.completions.create``.
        **model_kwargs: Accepted for interface compatibility; currently unused.

    Returns:
        A generator function ``predict(message, history, system_prompt,
        temperature, top_p, max_tokens)`` that yields the accumulated
        assistant response as chunks arrive.
    """
    client = OpenAI(
        # NOTE(review): endpoint and key are hard-coded; consider reading them
        # from environment variables. The original base_url contained a
        # leading space (" http://..."), which made the URL invalid — fixed.
        base_url="http://192.222.58.60:8000/v1",
        api_key="tela",
    )

    def predict(
        message: str,
        history,
        system_prompt: str,
        temperature: float,
        top_p: float,
        max_tokens: int,
    ):
        """Stream a chat completion for `message`, yielding partial text."""
        try:
            messages = []
            if system_prompt:
                messages.append({"role": "system", "content": system_prompt})
            # Replay prior turns as alternating user/assistant messages.
            for user_msg, assistant_msg in history:
                messages.append({"role": "user", "content": user_msg})
                messages.append({"role": "assistant", "content": assistant_msg})
            messages.append({"role": "user", "content": message})

            response = client.chat.completions.create(
                model=model_name,
                messages=messages,
                temperature=temperature,
                top_p=top_p,
                # BUG FIX: original passed undefined name `max_new_tokens`
                # (NameError at runtime); the parameter is `max_tokens`.
                max_tokens=max_tokens,
                n=1,
                stream=True,
                response_format={"type": "text"},
            )

            response_text = ""
            for chunk in response:
                chunk_message = chunk.choices[0].delta.content
                if chunk_message:
                    response_text += chunk_message
                    # BUG FIX: original yielded undefined `assistant_message`;
                    # yield the running text inside the loop so the UI streams.
                    yield response_text.strip()
        except Exception as e:
            # Top-level boundary for the UI: surface the error as chat output
            # instead of crashing the Gradio worker.
            print(f"Error during generation: {str(e)}")
            yield f"An error occurred: {str(e)}"

    return predict


def get_image_base64(url: str, ext: str) -> str:
    """Read a local file and return it as a ``data:image/...;base64,`` URI.

    Args:
        url: Path to the local file (despite the name, this is a filesystem path).
        ext: File extension (without the dot) used as the MIME subtype.
    """
    with open(url, "rb") as image_file:
        encoded_string = base64.b64encode(image_file.read()).decode("utf-8")
    return "data:image/" + ext + ";base64," + encoded_string


def handle_user_msg(message):
    """Normalize a Gradio multimodal message into OpenAI message content.

    A plain string passes through unchanged. A dict with ``files`` is turned
    into a text + image_url content list (only the last file is used).

    Raises:
        NotImplementedError: For unsupported file types or message shapes.

    NOTE(review): this helper is not called by ``predict`` above — verify
    whether it is wired in elsewhere or is dead code.
    """
    if isinstance(message, str):
        return message
    elif isinstance(message, dict):
        if message.get("files"):
            ext = os.path.splitext(message["files"][-1])[1].strip(".")
            # NOTE(review): "pdf" is accepted here but emitted with an
            # image/* MIME type — confirm the endpoint handles that.
            if ext.lower() in ["png", "jpg", "jpeg", "gif", "pdf"]:
                encoded_str = get_image_base64(message["files"][-1], ext)
            else:
                raise NotImplementedError(f"Not supported file type {ext}")
            content = [
                {"type": "text", "text": message.get("text", "")},
                {
                    "type": "image_url",
                    "image_url": {
                        "url": encoded_str,
                    },
                },
            ]
        else:
            content = message.get("text", "")
        return content
    else:
        raise NotImplementedError


def get_model_path(name: str = None, model_path: str = None) -> str:
    """Return the model identifier to use with the endpoint.

    ``model_path`` takes precedence over ``name``.

    Raises:
        ValueError: If neither argument is provided.
    """
    if model_path:
        return model_path
    if name:
        return name
    raise ValueError("Either name or model_path must be provided")


def registry(name: str = None, model_path: str = None, **kwargs):
    """Create a Gradio ChatInterface for the given model."""
    model_name = get_model_path(name, model_path)
    fn = get_fn(model_name, **kwargs)

    # BUG FIX: additional_inputs are passed positionally after (message,
    # history) in the order listed here. The original order was
    # [system, temperature, max-tokens, top-p], which fed the max-tokens
    # slider into predict's `top_p` parameter and vice versa. Reordered to
    # match predict(message, history, system_prompt, temperature, top_p,
    # max_tokens).
    interface = gr.ChatInterface(
        fn=fn,
        additional_inputs_accordion=gr.Accordion("⚙️ Parameters", open=False),
        additional_inputs=[
            gr.Textbox(
                "You are a helpful AI assistant.",
                label="System prompt",
            ),
            gr.Slider(0, 1, 0.7, label="Temperature"),
            gr.Slider(0, 1, 0.95, label="Top P sampling"),
            gr.Slider(128, 4096, 1024, label="Max new tokens"),
        ],
    )
    return interface