import gradio as gr
from openai import OpenAI
import os

ACCESS_TOKEN = os.getenv("HF_TOKEN")
print("Access token loaded.")
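# The OpenAI client below is pointed at Hugging Face's OpenAI-compatible
# serverless Inference API endpoint, so the standard `openai` SDK can be used
# for chat completions. HF_TOKEN must be available in the environment (for
# example as a Space secret) for authenticated requests.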
client = OpenAI(
    base_url="https://api-inference.huggingface.co/v1/",
    api_key=ACCESS_TOKEN,
)
print("OpenAI client initialized.")


def respond(
    user_message,
    chat_history,
    system_msg,
    max_tokens,
    temperature,
    top_p,
    frequency_penalty,
    seed,
    featured_model,
    custom_model
):
    """
    This function handles the chatbot response. It takes in:
    - user_message: the user's newly typed message
    - chat_history: the list of (user, assistant) message pairs
    - system_msg: the system instruction or system-level context
    - max_tokens: the maximum number of tokens to generate
    - temperature: sampling temperature
    - top_p: top-p (nucleus) sampling
    - frequency_penalty: penalize repeated tokens in the output
    - seed: a fixed seed for reproducibility; -1 means 'random'
    - featured_model: the chosen model name from the 'Featured Models' radio
    - custom_model: the optional custom model that overrides the featured one if provided
    """

    print(f"Received user message: {user_message}")
    print(f"System message: {system_msg}")
    print(f"Max tokens: {max_tokens}, Temperature: {temperature}, Top-P: {top_p}, Freq-Penalty: {frequency_penalty}, Seed: {seed}")
    print(f"Featured model: {featured_model}")
    print(f"Custom model: {custom_model}")
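    # A seed of -1 in the UI means "random": pass None so the backend picks one.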
    if seed == -1:
        seed = None
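    # Model precedence: the custom-model textbox (if filled in) overrides the
    # featured-model radio, and an empty selection falls back to a default.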
    model_to_use = custom_model.strip() if custom_model.strip() != "" else featured_model
    if model_to_use.strip() == "":
        model_to_use = "meta-llama/Llama-3.3-70B-Instruct"

    print(f"Model selected for inference: {model_to_use}")
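    # Build the OpenAI-style message list: an optional system message first,
    # then the prior user/assistant turns, then the newly typed user message.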
    messages = []
    if system_msg.strip():
        messages.append({"role": "system", "content": system_msg.strip()})

    for user_text, assistant_text in chat_history:
        if user_text:
            messages.append({"role": "user", "content": user_text})
        if assistant_text:
            messages.append({"role": "assistant", "content": assistant_text})

    messages.append({"role": "user", "content": user_message})
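    # For illustration, a one-turn conversation with a system prompt produces a
    # payload shaped like this (the contents here are hypothetical):
    # [
    #     {"role": "system", "content": "You are a helpful assistant."},
    #     {"role": "user", "content": "Hello!"},
    # ]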
    response_so_far = ""
    print("Sending request to the Hugging Face Inference API...")
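    # Stream the completion: each chunk carries a small text delta, which is
    # appended to response_so_far and yielded so the UI updates progressively.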
    try:
        for resp_chunk in client.chat.completions.create(
            model=model_to_use,
            max_tokens=max_tokens,
            stream=True,
            temperature=temperature,
            top_p=top_p,
            frequency_penalty=frequency_penalty,
            seed=seed,
            messages=messages,
        ):
            token_text = resp_chunk.choices[0].delta.content
            if token_text:  # some chunks (e.g. the final one) carry no content
                response_so_far += token_text
            yield response_so_far
    except Exception as e:
        error_text = f"[ERROR] {str(e)}"
        print(error_text)
        yield response_so_far + "\n\n" + error_text

    print("Completed response generation.")
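

# Models offered in the "Featured Models" radio. The custom-model textbox can
# point at any other text-generation model hosted on the Hugging Face Hub.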
models_list = [
    "meta-llama/Llama-3.3-70B-Instruct",
    "meta-llama/Llama-2-13B-chat-hf",
    "bigscience/bloom",
    "openlm-research/open_llama_7b",
    "facebook/opt-6.7b",
    "google/flan-t5-xxl",
]


def filter_models(search_term):
    """Filters the models_list by the given search_term and returns an update for the Radio component."""
    filtered = [m for m in models_list if search_term.lower() in m.lower()]
    return gr.update(choices=filtered)


with gr.Blocks(theme="Nymbo/Nymbo_Theme_5") as demo:
    gr.Markdown("# Serverless-TextGen-Hub (Enhanced)")
    gr.Markdown("**A comprehensive UI for text generation with a featured-models dropdown and a custom override**.")
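    # Holds the running conversation as a list of (user, assistant) tuples.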
    chat_history = gr.State([])
    with gr.Tab("Basic Settings"):
        with gr.Row():
            with gr.Column(elem_id="prompt-container"):
                system_msg = gr.Textbox(
                    label="System message",
                    placeholder="Enter system-level instructions or context here.",
                    lines=2
                )

                with gr.Accordion("Featured Models", open=True):
                    model_search = gr.Textbox(
                        label="Filter Models",
                        placeholder="Search for a featured model...",
                        lines=1
                    )
                    model_radio = gr.Radio(
                        label="Select a featured model below",
                        choices=models_list,
                        value=models_list[0],
                        interactive=True
                    )
                    model_search.change(filter_models, inputs=model_search, outputs=model_radio)

                custom_model_box = gr.Textbox(
                    label="Custom Model (Optional)",
                    info="If provided, this overrides the featured model above, e.g. 'meta-llama/Llama-3.3-70B-Instruct'.",
                    placeholder="Your huggingface.co/username/model_name path"
                )
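    # The sampling parameters on this tab are passed straight through to the
    # chat-completions request.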
    with gr.Tab("Advanced Settings"):
        with gr.Row():
            max_tokens_slider = gr.Slider(
                minimum=1,
                maximum=4096,
                value=512,
                step=1,
                label="Max new tokens"
            )
            temperature_slider = gr.Slider(
                minimum=0.1,
                maximum=4.0,
                value=0.7,
                step=0.1,
                label="Temperature"
            )
            top_p_slider = gr.Slider(
                minimum=0.1,
                maximum=1.0,
                value=0.95,
                step=0.05,
                label="Top-P"
            )
        with gr.Row():
            freq_penalty_slider = gr.Slider(
                minimum=-2.0,
                maximum=2.0,
                value=0.0,
                step=0.1,
                label="Frequency Penalty"
            )
            seed_slider = gr.Slider(
                minimum=-1,
                maximum=65535,
                value=-1,
                step=1,
                label="Seed (-1 for random)"
            )
    with gr.Row():
        chatbot = gr.Chatbot(
            label="TextGen Chat",
            height=500
        )

    user_input = gr.Textbox(
        label="Your message",
        placeholder="Type your text prompt here..."
    )

    send_button = gr.Button("Send")
    clear_button = gr.Button("Clear Chat")
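    # Sending a message is a two-step chain: user_submission records the user's
    # text in chat_history, then bot_response streams the assistant's reply
    # into the chatbot (the two are wired together with .then() below).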
    def user_submission(user_text, history):
        """
        This function gets called first to add the user's message to the chat.
        We return the updated chat_history with the user's message appended,
        plus an empty string for the next user input box.
        """
        if user_text.strip() == "":
            return history, ""

        history = history + [(user_text, None)]
        return history, ""
    # Keep a handle on the click event so the bot response can be chained
    # after the user's message has been recorded (see send_event.then below).
    send_event = send_button.click(
        fn=user_submission,
        inputs=[user_input, chat_history],
        outputs=[chat_history, user_input]
    )
    def bot_response(
        history,
        system_msg,
        max_tokens,
        temperature,
        top_p,
        freq_penalty,
        seed,
        featured_model,
        custom_model
    ):
        """
        This function is called to generate the assistant's response
        based on the conversation so far, system message, etc.
        We do the streaming here.
        """
        if not history:
            yield history
            return

        user_message = history[-1][0]

        bot_stream = respond(
            user_message=user_message,
            chat_history=history[:-1],
            system_msg=system_msg,
            max_tokens=max_tokens,
            temperature=temperature,
            top_p=top_p,
            frequency_penalty=freq_penalty,
            seed=seed,
            featured_model=featured_model,
            custom_model=custom_model
        )

        for partial_text in bot_stream:
            # Re-attach the streamed text to the pending (user, None) entry.
            updated_history = history[:-1] + [(history[-1][0], partial_text)]
            yield updated_history
    # Chain the bot response after user_submission so it always sees the
    # chat_history that already contains the new user message.
    send_event.then(
        fn=bot_response,
        inputs=[
            chat_history,
            system_msg,
            max_tokens_slider,
            temperature_slider,
            top_p_slider,
            freq_penalty_slider,
            seed_slider,
            model_radio,
            custom_model_box
        ],
        outputs=chatbot
    )
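    # "Clear Chat" resets the conversation state and the input box.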
    def clear_chat():
        return [], [], ""

    clear_button.click(
        fn=clear_chat,
        inputs=[],
        outputs=[chat_history, chatbot, user_input]
    )


if __name__ == "__main__":
    print("Launching the Serverless-TextGen-Hub with Featured Models & Custom Model override.")
    demo.launch()