import os
import sys
import datetime

from openai import OpenAI
import gradio as gr
from gradio.components.chatbot import ChatMessage, Message
from typing import (
    Any,
    Literal,
)

DEBUG_LOG = os.environ.get("DEBUG_LOG") == "True"

print(f"Gradio version: {gr.__version__}")

title = None  # e.g. "ServiceNow-AI Chat" or model_config.get('MODE_DISPLAY_NAME')
description = "Please use the community section on this space to provide feedback! <a href=\"https://huggingface.co/ServiceNow-AI/Apriel-Nemotron-15b-Thinker/discussions\">ServiceNow-AI/Apriel-Nemotron-Chat</a>"

chat_start_count = 0

model_config = {
    "MODEL_NAME": os.environ.get("MODEL_NAME"),
    "MODE_DISPLAY_NAME": os.environ.get("MODE_DISPLAY_NAME"),
    "MODEL_HF_URL": os.environ.get("MODEL_HF_URL"),
    "VLLM_API_URL": os.environ.get("VLLM_API_URL"),
    "AUTH_TOKEN": os.environ.get("AUTH_TOKEN")
}
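
# Example configuration (hypothetical values; the real ones are supplied via
# environment variables / Space secrets):
#   MODEL_NAME=ServiceNow-AI/Apriel-Nemotron-15b-Thinker
#   MODE_DISPLAY_NAME=Apriel Nemotron 15b Thinker
#   VLLM_API_URL=https://<vllm-host>/v1
#   AUTH_TOKEN=<api token>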

# Initialize the OpenAI client with the vLLM API URL and token
client = OpenAI(
    api_key=model_config.get('AUTH_TOKEN'),
    base_url=model_config.get('VLLM_API_URL')
)
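
# vLLM serves an OpenAI-compatible REST API, which is why the stock OpenAI
# client works here with base_url pointed at the vLLM endpoint.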


def log_message(message):
    if DEBUG_LOG is True:
        print(message)


# Gradio 5.0.1 had issues with checking the message formats.  5.29.0 does not!
def _check_format(messages: Any, type: Literal["messages", "tuples"] = "messages") -> None:
    if type == "messages":
        all_valid = all(
            (isinstance(message, dict)
             and "role" in message
             and "content" in message)
            or isinstance(message, ChatMessage | Message)
            for message in messages
        )
        if not all_valid:
            # Display which message is not valid
            for i, message in enumerate(messages):
                if not (isinstance(message, dict) and
                        "role" in message and
                        "content" in message) and not isinstance(message, ChatMessage | Message):
                    print(f"_check_format() --> Invalid message at index {i}: {message}\n", file=sys.stderr)
                    break

            raise Exception(
                "Data incompatible with messages format. Each message should be a dictionary with 'role' and 'content' keys or a ChatMessage object."
            )
        # else:
        #     print("_check_format() --> All messages are valid.")
    elif not all(
            isinstance(message, (tuple, list)) and len(message) == 2
            for message in messages
    ):
        raise Exception(
            "Data incompatible with tuples format. Each message should be a list of length 2."
        )
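
# Example (hypothetical calls): both of these pass validation.
#   _check_format([{"role": "user", "content": "hi"}], "messages")
#   _check_format([("hi", "hello!")], "tuples")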


def chat_fn(message, history):
    log_message(f"{'-' * 80}\nchat_fn() --> Message: {message}")

    global chat_start_count
    chat_start_count += 1
    print(
        f"{datetime.datetime.now()}: chat_start_count: {chat_start_count}, "
        f"turns: {len(history or []) // 3}")  # ~3 messages per turn: user, thought, answer

    # Remove any assistant messages with metadata from history for multiple turns
    log_message(f"Original History: {history}")
    _check_format(history, "messages")
    history = [item for item in history if
               not (isinstance(item, dict) and
                    item.get("role") == "assistant" and
                    isinstance(item.get("metadata"), dict) and
                    item.get("metadata", {}).get("title") is not None)]
    log_message(f"Updated History: {history}")
    _check_format(history, "messages")

    history.append({"role": "user", "content": message})
    log_message(f"History with user message: {history}")
    _check_format(history, "messages")

    # Create the streaming response
    stream = client.chat.completions.create(
        model=model_config.get('MODEL_NAME'),
        messages=history,
        temperature=0.8,
        stream=True
    )

    history.append(gr.ChatMessage(
        role="assistant",
        content="Thinking...",
        metadata={"title": "🧠 Thought"}
    ))
    log_message(f"History added thinking: {history}")
    _check_format(history, "messages")
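
    # The Thinker model streams its reasoning first, then the final answer wrapped
    # in "[BEGIN FINAL RESPONSE]" / "[END FINAL RESPONSE]" markers. Illustrative
    # (not verbatim) accumulated output:
    #   "Let me think step by step...[BEGIN FINAL RESPONSE]The answer is 42.[END FINAL RESPONSE]\n<|end|>"
    # The loop below splits on the opening marker and renders the two parts as
    # separate assistant messages.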

    output = ""
    completion_started = False
    for chunk in stream:
        # Extract the new content from the delta field; it can be None on the
        # final chunk, so coalesce to an empty string
        content = chunk.choices[0].delta.content or ""
        output += content

        parts = output.split("[BEGIN FINAL RESPONSE]")

        if len(parts) > 1:
            # Strip the closing marker once it arrives (check the longer suffix
            # first so the trailing end-of-text token is removed with it)
            if parts[1].endswith("[END FINAL RESPONSE]\n<|end|>"):
                parts[1] = parts[1].removesuffix("[END FINAL RESPONSE]\n<|end|>")
            elif parts[1].endswith("[END FINAL RESPONSE]"):
                parts[1] = parts[1].removesuffix("[END FINAL RESPONSE]")

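        # The "thought" bubble stays the last message until the final-answer
        # message is appended below; after that it sits second-to-last.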
        history[-1 if not completion_started else -2] = gr.ChatMessage(
            role="assistant",
            content=parts[0],
            metadata={"title": "🧠 Thought"}
        )
        if completion_started:
            history[-1] = gr.ChatMessage(
                role="assistant",
                content=parts[1]
            )
        elif len(parts) > 1 and not completion_started:
            completion_started = True
            history.append(gr.ChatMessage(
                role="assistant",
                content=parts[1]
            ))

        # only yield the most recent assistant messages
        messages_to_yield = history[-1:] if not completion_started else history[-2:]
        # _check_format(messages_to_yield, "messages")
        yield messages_to_yield

    log_message(f"Final History: {history}")
    _check_format(history, "messages")


# Add the model display name and Hugging Face URL to the description
# description = f"### Model: [{MODE_DISPLAY_NAME}]({MODEL_HF_URL})"

print(f"Running model {model_config.get('MODE_DISPLAY_NAME')} ({model_config.get('MODEL_NAME')})")

gr.ChatInterface(
    chat_fn,
    title=title,
    description=description,
    theme=gr.themes.Default(primary_hue="green"),
    type="messages",
).launch()
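
# Local run sketch (assumed invocation; point the variables at your own vLLM
# deployment and adjust the filename to match this script):
#   MODEL_NAME=<served model> VLLM_API_URL=http://localhost:8000/v1 \
#   AUTH_TOKEN=<token> python app.py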