import os
import sys
import time
from functools import wraps
from typing import Any, Literal

from gradio import ChatMessage
from gradio.components.chatbot import Message

COMMUNITY_POSTFIX_URL = "/discussions"
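# Debug flags, enabled by setting the corresponding environment variable to "True"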
DEBUG_MODE = os.environ.get("DEBUG_MODE") == "True"
DEBUG_MODEL = os.environ.get("DEBUG_MODEL") == "True"

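# Per-model settings: display name and HF URL are static; the served model name,
# vLLM endpoint and auth token are read from the environment.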
models_config = {
    "Apriel-Nemotron-15b-Thinker": {
        "MODEL_DISPLAY_NAME": "Apriel-Nemotron-15b-Thinker",
        "MODEL_HF_URL": "https://huggingface.co/ServiceNow-AI/Apriel-Nemotron-15b-Thinker",
        "MODEL_NAME": os.environ.get("MODEL_NAME_NEMO_15B"),
        "VLLM_API_URL": os.environ.get("VLLM_API_URL_NEMO_15B"),
        "AUTH_TOKEN": os.environ.get("AUTH_TOKEN"),
        "REASONING": True
    },
    "Apriel-5b": {
        "MODEL_DISPLAY_NAME": "Apriel-5b",
        "MODEL_HF_URL": "https://huggingface.co/ServiceNow-AI/Apriel-5B-Instruct",
        "MODEL_NAME": os.environ.get("MODEL_NAME_5B"),
        "VLLM_API_URL": os.environ.get("VLLM_API_URL_5B"),
        "AUTH_TOKEN": os.environ.get("AUTH_TOKEN"),
        "REASONING": False
    }
}


def get_model_config(model_name: str) -> dict:
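    """Return the config for `model_name`, validating that the served model name and vLLM URL are set."""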
    config = models_config.get(model_name)
    if not config:
        raise ValueError(f"Model {model_name} not found in models_config")
    if not config.get("MODEL_NAME"):
        raise ValueError(f"Model name not found in config for {model_name}")
    if not config.get("VLLM_API_URL"):
        raise ValueError(f"VLLM API URL not found in config for {model_name}")

    return config


def _log_message(prefix, message, icon=""):
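    """Print a timestamped log line with a level prefix and an optional icon."""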
    timestamp = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime())
    if len(icon) > 0:
        icon = f"{icon} "
    print(f"{timestamp}: {prefix} {icon}{message}")


def log_debug(message):
    if DEBUG_MODE:
        _log_message("DEBUG", message)


def log_info(message):
    _log_message("INFO ", message)


def log_warning(message):
    _log_message("WARN ", message, "⚠️")


def log_error(message):
    _log_message("ERROR", message, "‼️")


# Gradio 5.0.1 had issues with checking the message formats; 5.29.0 does not.
def check_format(messages: Any, type: Literal["messages", "tuples"] = "messages") -> None:
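    """Debug-only sanity check that `messages` matches the expected Gradio chat format ("messages" or "tuples")."""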
    if not DEBUG_MODE:
        return

    if type == "messages":
        all_valid = all(
            (isinstance(message, dict) and "role" in message and "content" in message)
            or isinstance(message, ChatMessage | Message)
            for message in messages
        )
        if not all_valid:
            # Display which message is not valid
            for i, message in enumerate(messages):
                if not (isinstance(message, dict) and
                        "role" in message and
                        "content" in message) and not isinstance(message, ChatMessage | Message):
                    print(f"_check_format() --> Invalid message at index {i}: {message}\n", file=sys.stderr)
                    break

            raise Exception(
                "Data incompatible with messages format. Each message should be a dictionary with 'role' and 'content' keys or a ChatMessage object."
            )
        # else:
        #     print("_check_format() --> All messages are valid.")
    elif not all(
            isinstance(message, (tuple, list)) and len(message) == 2
            for message in messages
    ):
        raise Exception(
            "Data incompatible with tuples format. Each message should be a list of length 2."
        )


# Adds timing info for a gradio event handler (non-generator functions)
def logged_event_handler(log_msg='', event_handler=None, log_timer=None, clear_timer=False):
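    """Wrap a (non-generator) Gradio event handler so start/end are logged and, if given, recorded on `log_timer`."""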
    @wraps(event_handler)
    def wrapped_event_handler(*args, **kwargs):
        # Log before
        if log_timer:
            if clear_timer:
                log_timer.clear()
            log_timer.add_step(f"Start: {log_debug}")
        log_debug(f"::: Before event: {log_msg}")

        # Call the original event handler
        result = event_handler(*args, **kwargs)

        # Log after
        if log_timer:
            log_timer.add_step(f"Completed: {log_msg}")
        log_debug(f"::: After event: {log_msg}")

        return result

    return wrapped_event_handler