# Apriel-Chat / utils.py
# Snapshot of commit 0bb4279 by bradnow: "Add stop button and new styles" (3.7 kB).
import os
import sys
from typing import Any, Literal
from gradio import ChatMessage
from gradio.components.chatbot import Message
from functools import wraps
# URL suffix for a community/discussions page (presumably appended to a
# model's MODEL_HF_URL -- confirm at the call sites).
COMMUNITY_POSTFIX_URL = "/discussions"

# Debug helpers are active only when the DEBUG_MODE environment variable is
# set to the exact string "True"; any other value (or none) disables them.
DEBUG_MODE = os.environ.get("DEBUG_MODE") == "True"
def _model_entry(display_name, hf_url, name_env, url_env, reasoning):
    """Build one ``models_config`` entry.

    ``name_env`` and ``url_env`` are the names of the environment variables
    holding the served model name and the vLLM API URL. Unset variables
    yield ``None``. The auth token always comes from ``AUTH_TOKEN``.
    """
    return {
        "MODEL_DISPLAY_NAME": display_name,
        "MODEL_HF_URL": hf_url,
        "MODEL_NAME": os.environ.get(name_env),
        "VLLM_API_URL": os.environ.get(url_env),
        "AUTH_TOKEN": os.environ.get("AUTH_TOKEN"),
        "REASONING": reasoning,
    }


# Per-model settings keyed by model display name. Endpoint details are read
# from the environment once, at import time.
models_config = {
    "Apriel-Nemotron-15b-Thinker": _model_entry(
        "Apriel-Nemotron-15b-Thinker",
        "https://huggingface.co/ServiceNow-AI/Apriel-Nemotron-15b-Thinker",
        "MODEL_NAME_NEMO_15B",
        "VLLM_API_URL_NEMO_15B",
        True,
    ),
    "Apriel-5b": _model_entry(
        "Apriel-5b",
        "https://huggingface.co/ServiceNow-AI/Apriel-5B-Instruct",
        "MODEL_NAME_5B",
        "VLLM_API_URL_5B",
        False,
    ),
}
def get_model_config(model_name: str) -> dict:
    """Return the ``models_config`` entry for ``model_name`` after validation.

    Args:
        model_name: Key to look up in ``models_config``.

    Returns:
        The stored configuration dict itself (not a copy).

    Raises:
        ValueError: If there is no (non-empty) entry for ``model_name``, or
            the entry has no truthy ``MODEL_NAME`` or ``VLLM_API_URL``.
    """
    config = models_config.get(model_name)
    if not config:
        raise ValueError(f"Model {model_name} not found in models_config")

    # Required settings, checked in this order; the first missing one wins.
    required = (
        ("MODEL_NAME", f"Model name not found in config for {model_name}"),
        ("VLLM_API_URL", f"VLLM API URL not found in config for {model_name}"),
    )
    for key, error_text in required:
        if not config.get(key):
            raise ValueError(error_text)

    return config
def log_message(message):
    """Print ``message`` to stdout behind a ``≫≫≫`` prefix in debug mode.

    Nothing is printed unless the module-level ``DEBUG_MODE`` is exactly
    ``True``.
    """
    if DEBUG_MODE is not True:
        return
    line = f"≫≫≫ {message}"
    print(line)
# Gradio 5.0.1 had issues with checking the message formats. 5.29.0 does not!
def check_format(messages: Any, type: Literal["messages", "tuples"] = "messages") -> None:
    """Validate the shape of a chat history; a no-op unless ``DEBUG_MODE`` is on.

    Args:
        messages: Iterable of chat entries to validate.
        type: ``"messages"`` requires every entry to be a dict with both
            ``"role"`` and ``"content"`` keys, or a ``ChatMessage`` /
            ``Message`` instance. Any other value requires every entry to be
            a tuple or list of length 2.

    Raises:
        Exception: If any entry does not match the requested format. For the
            ``"messages"`` format the first offending entry is also printed
            to stderr.
    """
    if not DEBUG_MODE:
        return

    if type != "messages":
        pairs_ok = all(
            isinstance(entry, (tuple, list)) and len(entry) == 2
            for entry in messages
        )
        if not pairs_ok:
            raise Exception(
                "Data incompatible with tuples format. Each message should be a list of length 2."
            )
        return

    def is_valid(entry) -> bool:
        # Dict check first, so the Gradio types are only consulted for non-dict entries.
        if isinstance(entry, dict) and "role" in entry and "content" in entry:
            return True
        return isinstance(entry, ChatMessage | Message)

    if all(is_valid(entry) for entry in messages):
        return

    # Point at the first entry that fails validation to ease debugging.
    for index, entry in enumerate(messages):
        if not is_valid(entry):
            print(f"_check_format() --> Invalid message at index {index}: {entry}\n", file=sys.stderr)
            break

    raise Exception(
        "Data incompatible with messages format. Each message should be a dictionary with 'role' and 'content' keys or a ChatMessage object."
    )
def logged_event_handler(log_message='', event_handler=None, timer=None, clear_timer=False):
    """Wrap ``event_handler`` so every call is logged before and after it runs.

    Args:
        log_message: Label describing the event; used in the log lines and
            in the timer step names.
        event_handler: Callable to wrap. It receives the wrapper's positional
            and keyword arguments unchanged.
        timer: Optional object exposing ``clear()`` and ``add_step(str)``.
            When falsy, no timer steps are recorded.
        clear_timer: If true and a timer is given, the timer is cleared
            before the "Start" step is recorded.

    Returns:
        A function (carrying ``event_handler``'s metadata via
        ``functools.wraps``) that returns whatever ``event_handler`` returns.
    """
    # The ``log_message`` parameter shadows the module-level ``log_message()``
    # function inside this scope, so calling ``log_message(...)`` here would
    # try to call a string and raise TypeError. Keep the parameter as the
    # label and look the logging function up in the module globals instead.
    label = log_message

    @wraps(event_handler)
    def wrapped_event_handler(*args, **kwargs):
        emit = globals()["log_message"]

        # Log before
        if timer:
            if clear_timer:
                timer.clear()
            timer.add_step(f"Start: {label}")
        emit(f"::: Before event: {label}")

        # Call the original event handler
        result = event_handler(*args, **kwargs)

        # Log after
        if timer:
            timer.add_step(f"Completed: {label}")
        emit(f"::: After event: {label}")

        return result

    return wrapped_event_handler