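"""Gradio chat UI for a reasoning model served behind a vLLM OpenAI-compatible API.

The backend is configured entirely through environment variables (see
``model_config`` below). Assistant output is expected to wrap its final answer
in ``[BEGIN FINAL RESPONSE]`` / ``[END FINAL RESPONSE]`` markers; everything
before the opening marker is rendered as a collapsible "Thought" bubble.
"""
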
import os
import sys
from typing import Any, Literal

import gradio as gr
from gradio.components.chatbot import ChatMessage, Message
from openai import OpenAI

title = None  # e.g. "ServiceNow-AI Chat" or model_config.get("MODE_DISPLAY_NAME")
description = None

model_config = {
    "MODEL_NAME": os.environ.get("MODEL_NAME"),
    "MODE_DISPLAY_NAME": os.environ.get("MODE_DISPLAY_NAME"),
    "MODEL_HF_URL": os.environ.get("MODEL_HF_URL"),
    "VLLM_API_URL": os.environ.get("VLLM_API_URL"),
    "AUTH_TOKEN": os.environ.get("AUTH_TOKEN")
}
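
# Illustrative values only (placeholders, adjust for your deployment); vLLM's
# OpenAI-compatible server conventionally listens at http://<host>:8000/v1:
#   MODEL_NAME        = "<model id as served by vLLM>"
#   MODE_DISPLAY_NAME = "<human-readable model name>"
#   MODEL_HF_URL      = "https://huggingface.co/<org>/<model>"
#   VLLM_API_URL      = "http://<host>:8000/v1"
#   AUTH_TOKEN        = "<bearer token for the vLLM server>"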

# Initialize the OpenAI client with the vLLM API URL and token
client = OpenAI(
    api_key=model_config.get('AUTH_TOKEN'),
    base_url=model_config.get('VLLM_API_URL')
)


def _check_format(messages: Any, type: Literal["messages", "tuples"] = "messages") -> None:
    """Validate that `messages` matches the expected Gradio chat history format."""
    if type == "messages":
        all_valid = all(
            (isinstance(message, dict) and "role" in message and "content" in message)
            or isinstance(message, (ChatMessage, Message))
            for message in messages
        )
        if not all_valid:
            # Report the first invalid message before raising
            for i, message in enumerate(messages):
                if not (isinstance(message, dict) and
                        "role" in message and
                        "content" in message) and not isinstance(message, (ChatMessage, Message)):
                    print(f"_check_format() --> Invalid message at index {i}: {message}\n", file=sys.stderr)
                    break

            raise Exception(
                "Data incompatible with messages format. Each message should be a dictionary "
                "with 'role' and 'content' keys or a ChatMessage object."
            )
        else:
            print("_check_format() --> All messages are valid.")
    elif not all(
            isinstance(message, (tuple, list)) and len(message) == 2
            for message in messages
    ):
        raise Exception(
            "Data incompatible with tuples format. Each message should be a tuple or list of length 2."
        )

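# For example, _check_format([{"role": "user", "content": "hi"}], "messages")
# prints a confirmation, while _check_format([("hi", "hi!")], "messages") raises.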

def chat_fn(message, history):
    """Stream a chat completion, separating the model's reasoning from its final answer."""
    print(f"{'-' * 80}\nchat_fn() --> Message: {message}")
    # Drop prior assistant "thought" messages (those carrying title metadata)
    # from history so only real turns are re-sent to the model
    print(f"Original History: {history}")
    _check_format(history, "messages")
    history = [item for item in history if
               not (isinstance(item, dict) and
                    item.get("role") == "assistant" and
                    isinstance(item.get("metadata"), dict) and
                    item.get("metadata", {}).get("title") is not None)]
    print(f"Updated History: {history}")
    _check_format(history, "messages")

    history.append({"role": "user", "content": message})
    print(f"History with user message: {history}")
    _check_format(history, "messages")

    # Create the streaming response
    stream = client.chat.completions.create(
        model=model_config.get('MODEL_NAME'),
        messages=history,
        temperature=0.8,
        stream=True
    )

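    # An assistant message whose metadata includes a "title" is rendered by
    # Gradio as a collapsible bubble; such messages are also what the filter
    # at the top of chat_fn strips from history on later turns.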
    history.append(gr.ChatMessage(
        role="assistant",
        content="Thinking...",
        metadata={"title": "🧠 Thought"}
    ))
    print(f"History added thinking: {history}")
    _check_format(history, "messages")

    output = ""
    completion_started = False
    for chunk in stream:
        # Extract the new content from the delta field
        content = getattr(chunk.choices[0].delta, "content", "")
        output += content

        parts = output.split("[BEGIN FINAL RESPONSE]")

        if len(parts) > 1:
            if parts[1].endswith("[END FINAL RESPONSE]"):
                parts[1] = parts[1].replace("[END FINAL RESPONSE]", "")
            if parts[1].endswith("[END FINAL RESPONSE]\n<|end|>"):
                parts[1] = parts[1].replace("[END FINAL RESPONSE]\n<|end|>", "")

        # Update the "Thought" bubble with the reasoning streamed so far; once
        # the final answer starts, the thought message sits one slot back
        history[-1 if not completion_started else -2] = gr.ChatMessage(
            role="assistant",
            content=parts[0],
            metadata={"title": "🧠 Thought"}
        )
        if completion_started:
            history[-1] = gr.ChatMessage(
                role="assistant",
                content=parts[1]
            )
        elif len(parts) > 1:
            # The opening marker just arrived: start a separate message
            # for the final answer
            completion_started = True
            history.append(gr.ChatMessage(
                role="assistant",
                content=parts[1]
            ))

        # Only yield the most recent assistant message(s): the thought bubble,
        # plus the final answer once it has started
        messages_to_yield = history[-1:] if not completion_started else history[-2:]
        yield messages_to_yield

    print(f"Final History: {history}")
    _check_format(history, "messages")


# Optionally surface the model display name and Hugging Face URL in the UI:
# description = f"### Model: [{model_config.get('MODE_DISPLAY_NAME')}]({model_config.get('MODEL_HF_URL')})"

print(f"Running model {model_config.get('MODE_DISPLAY_NAME')} ({model_config.get('MODEL_NAME')})")

gr.ChatInterface(
    chat_fn,
    title=title,
    description=description,
    theme=gr.themes.Default(primary_hue="green"),
    type="messages",
).launch()
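
# Note: ChatInterface(type="messages") consumes the ChatMessage lists yielded by
# chat_fn as they stream in; pass share=True to launch() to expose a public URL.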