import base64
import os
from typing import Callable, Generator

import gradio as gr
from openai import OpenAI

def get_fn(model_name: str, **model_kwargs) -> Callable:
    """Create a streaming chat function bound to the specified model."""

    # Instantiate an OpenAI client for a custom OpenAI-compatible endpoint
    try:
        client = OpenAI(
            base_url="http://192.222.58.60:8000/v1",
            api_key="tela",
        )
    except Exception as e:
        print(f"Failed to initialize the OpenAI client: {e}")
        raise  # Fail fast rather than letting the app run without a client

    def predict(
        messages: list,
        system_prompt: str,
        temperature: float,
        max_tokens: int,
        top_k: int,
        repetition_penalty: float,
        top_p: float,
    ) -> Generator[str, None, None]:
        try:
            # Prepend the system prompt; `messages` already holds the
            # preprocessed history plus the latest user message
            messages = [{"role": "system", "content": system_prompt}] + messages
            # Call the OpenAI-compatible API with the formatted messages
            response = client.chat.completions.create(
                model=model_name,
                messages=messages,
                temperature=temperature,
                max_tokens=max_tokens,
                top_p=top_p,
                stream=True,
                # top_k and repetition_penalty are not part of the official
                # OpenAI API; OpenAI-compatible servers (e.g. vLLM) accept
                # them as extra sampling parameters via extra_body.
                extra_body={
                    "top_k": top_k,
                    "repetition_penalty": repetition_penalty,
                },
            )
    
            response_text = ""
            # Iterate over the streaming response; chunks are pydantic
            # objects, so use attribute access rather than dict indexing
            for chunk in response:
                if chunk.choices and chunk.choices[0].delta:
                    content = chunk.choices[0].delta.content or ""
                    if content:
                        response_text += content
                        yield response_text.strip()
            
            if not response_text.strip():
                yield "I apologize, but I was unable to generate a response. Please try again."
    
        except Exception as e:
            print(f"Error during generation: {str(e)}")
            yield f"An error occurred: {str(e)}"
    
    return predict
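
# Example (hypothetical model name): build the chat function and stream a reply.
#   chat_fn = get_fn("my-model")
#   for partial in chat_fn(
#       messages=[{"role": "user", "content": "Hello"}],
#       system_prompt="You are a helpful AI assistant.",
#       temperature=0.7, max_tokens=256, top_k=40,
#       repetition_penalty=1.1, top_p=0.95,
#   ):
#       print(partial)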



def get_image_base64(path: str, ext: str):
    """Read a local image file and return it as a base64 data URL."""
    with open(path, "rb") as image_file:
        encoded_string = base64.b64encode(image_file.read()).decode("utf-8")
    return "data:image/" + ext + ";base64," + encoded_string


def handle_user_msg(message: str):
    """Convert a Gradio chat message (plain string or multimodal dict) into OpenAI content."""
    if isinstance(message, str):
        return message
    elif isinstance(message, dict):
        if message.get("files"):
            ext = os.path.splitext(message["files"][-1])[1].strip(".")
            if ext.lower() in ["png", "jpg", "jpeg", "gif", "pdf"]:
                encoded_str = get_image_base64(message["files"][-1], ext)
            else:
                raise NotImplementedError(f"Unsupported file type: {ext}")
            # Pair the text with the attached file as multimodal content parts
            content = [
                {"type": "text", "text": message["text"]},
                {
                    "type": "image_url",
                    "image_url": {
                        "url": encoded_str,
                    },
                },
            ]
        else:
            content = message["text"]
        return content
    else:
        raise NotImplementedError(f"Unsupported message type: {type(message)}")
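
# Example: a multimodal Gradio message such as
#   {"text": "What is in this image?", "files": ["photo.jpg"]}
# (hypothetical file name) is converted into OpenAI-style content parts:
#   [{"type": "text", "text": "What is in this image?"},
#    {"type": "image_url", "image_url": {"url": "data:image/jpg;base64,..."}}]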


def get_interface_args(pipeline: str):
    if pipeline == "chat":
        inputs = None
        outputs = None

        def preprocess(message, history):
            messages = []
            files = None
            for user_msg, assistant_msg in history:
                if assistant_msg is not None:
                    messages.append({"role": "user", "content": handle_user_msg(user_msg)})
                    messages.append({"role": "assistant", "content": assistant_msg})
                else:
                    # A history turn with no assistant reply is a bare file
                    # upload; carry the file(s) over into the current message
                    files = user_msg
            if isinstance(message, str) and files is not None:
                message = {"text": message, "files": files}
            elif isinstance(message, dict) and files is not None:
                if not message.get("files"):
                    message["files"] = files
            messages.append({"role": "user", "content": handle_user_msg(message)})
            return {"messages": messages}

        def postprocess(x):
            return x  # no additional postprocessing needed for plain chat

    else:
        # Add other pipeline types when they are needed
        raise ValueError(f"Unsupported pipeline type: {pipeline}")
    return inputs, outputs, preprocess, postprocess


def registry(name: str | None = None, **kwargs) -> gr.ChatInterface:
    """Create a Gradio ChatInterface wired to the custom OpenAI-compatible endpoint."""
    
    # Retrieve preprocess and postprocess functions
    _, _, preprocess, postprocess = get_interface_args("chat")
    
    # Get the predict function
    predict_fn = get_fn(model_name=name, **kwargs)
    
    # Define a wrapper that connects Gradio's inputs to the predict function
    def wrapper(message, history, system_prompt, temperature, max_tokens, top_k, repetition_penalty, top_p):
        # Preprocess the raw Gradio inputs into an OpenAI-style message list
        preprocessed = preprocess(message, history)

        # Stream partial responses back to Gradio as they are generated;
        # each yielded value replaces the previous one in the chat window
        yield from predict_fn(
            messages=preprocessed["messages"],
            system_prompt=system_prompt,
            temperature=temperature,
            max_tokens=max_tokens,
            top_k=top_k,
            repetition_penalty=repetition_penalty,
            top_p=top_p,
        )

    # Create the Gradio ChatInterface with the wrapper function
    interface = gr.ChatInterface(
        fn=wrapper,
        additional_inputs_accordion=gr.Accordion("⚙️ Parameters", open=False),
        additional_inputs=[
            gr.Textbox(
                value="You are a helpful AI assistant.",
                label="System prompt"
            ),
            gr.Slider(0.0, 1.0, value=0.7, label="Temperature"),
            gr.Slider(128, 4096, value=1024, label="Max new tokens"),
            gr.Slider(1, 80, value=40, step=1, label="Top K sampling"),
            gr.Slider(0.0, 2.0, value=1.1, label="Repetition penalty"),
            gr.Slider(0.0, 1.0, value=0.95, label="Top P sampling"),
        ],
        # Optionally, you can customize other ChatInterface parameters here
    )
    
    return interface
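
# A minimal way to launch this module directly, assuming the endpoint above is
# reachable. The model name "my-model" is a placeholder, not taken from the
# source; replace it with whatever model the endpoint actually serves.
if __name__ == "__main__":
    demo = registry(name="my-model")
    demo.launch()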