import os

import gradio as gr
from huggingface_hub import InferenceClient
from sqlalchemy import create_engine, text
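
# Hugging Face Inference API clients, one per hosting provider; all share HF_TOKEN.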
llama_client = InferenceClient(provider="sambanova", api_key=os.environ["HF_TOKEN"])
minimax_client = InferenceClient(provider="novita", api_key=os.environ["HF_TOKEN"])
mistral_client = InferenceClient(provider="together", api_key=os.environ["HF_TOKEN"])
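
# Module-level cache so a single database connection is reused across requests.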
db_connection = None


def get_sqlalchemy_connection():
    """Open a new SQL Server connection via SQLAlchemy and the pymssql driver."""
    server = os.getenv("SQL_SERVER")
    database = os.getenv("SQL_DATABASE")
    username = os.getenv("SQL_USERNAME")
    password = os.getenv("SQL_PASSWORD")

    connection_url = f"mssql+pymssql://{username}:{password}@{server}/{database}"

    try:
        engine = create_engine(connection_url)
        conn = engine.connect()
        print("✅ SQLAlchemy + pymssql connection successful")
        return conn
    except Exception as e:
        print(f"❌ SQLAlchemy connection failed: {e}")
        return None
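
# Reuse the cached connection if it is still alive; otherwise open a new one.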
def get_sql_connection():
    global db_connection

    if db_connection is not None:
        try:
            # SQLAlchemy Connection objects have no cursor(); ping with a no-op query.
            db_connection.execute(text("SELECT 1"))
            return db_connection
        except Exception as e:
            print(f"❌ Cached SQL connection is stale, reconnecting: {e}")
            db_connection = None

    db_connection = get_sqlalchemy_connection()
    return db_connection
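
# Note: chat_with_model only checks connectivity; nothing is ever written to the
# database. If persisting turns is the goal, a minimal sketch could look like the
# helper below (the chat_log table is a hypothetical schema, not part of this app).
def log_chat_turn(conn, role, content):
    # Hypothetical: assumes a chat_log(role, content) table already exists.
    conn.execute(
        text("INSERT INTO chat_log (role, content) VALUES (:role, :content)"),
        {"role": role, "content": content},
    )
    conn.commit()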
def format_chat_history(chat_history):
    formatted = ""
    for msg in chat_history:
        role = msg["role"]
        content = msg["content"]
        if isinstance(content, list):
            for item in content:
                if "text" in item:
                    formatted += f"**{role.capitalize()}:** {item['text']}\n\n"
                elif "image_url" in item:
                    formatted += f"**{role.capitalize()}:** 🖼️ Image: {item['image_url']['url']}\n\n"
        else:
            formatted += f"**{role.capitalize()}:** {content}\n\n"
    return formatted.strip()
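
# Route one user turn to the chosen model; the full history is resent each call.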
def chat_with_model(model_choice, prompt, image_url, chat_history):
    if not prompt:
        return "❌ Please enter a text prompt.", chat_history, "", ""

    if chat_history is None:
        chat_history = []

    # Require a reachable database before answering.
    conn = get_sql_connection()
    if conn is None:
        return "❌ Failed to connect to database.", chat_history, "", ""

    try:
        if model_choice == "LLaMA 4 (SambaNova)":
            # LLaMA 4 accepts multimodal content: text plus an optional image URL.
            user_msg = [{"type": "text", "text": prompt}]
            if image_url:
                user_msg.append({"type": "image_url", "image_url": {"url": image_url}})
            chat_history.append({"role": "user", "content": user_msg})

            response = llama_client.chat.completions.create(
                model="meta-llama/Llama-4-Maverick-17B-128E-Instruct",
                messages=chat_history,
            )
            bot_msg = response.choices[0].message.content
            chat_history.append({"role": "assistant", "content": bot_msg})

        elif model_choice == "MiniMax M1 (Novita)":
            chat_history.append({"role": "user", "content": prompt})
            response = minimax_client.chat.completions.create(
                model="MiniMaxAI/MiniMax-M1-80k",
                messages=chat_history,
            )
            bot_msg = response.choices[0].message.content
            chat_history.append({"role": "assistant", "content": bot_msg})

        elif model_choice == "Mistral Mixtral-8x7B (Together)":
            chat_history.append({"role": "user", "content": prompt})
            response = mistral_client.chat.completions.create(
                model="mistralai/Mixtral-8x7B-Instruct-v0.1",
                messages=chat_history,
            )
            bot_msg = response.choices[0].message.content
            chat_history.append({"role": "assistant", "content": bot_msg})

        else:
            return "❌ Unsupported model selected.", chat_history, "", ""

        return format_chat_history(chat_history), chat_history, "", ""

    except Exception as e:
        return f"❌ Error: {e}", chat_history, "", ""
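
# Gradio UI: model picker, prompt and optional image inputs, and per-session state.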
with gr.Blocks() as demo:
    gr.Markdown("## 🤖 Multi-Model Context-Aware Chatbot")
    gr.Markdown("Supports LLaMA 4 (with optional image), MiniMax, and Mistral. Memory is preserved for multi-turn dialog.")

    model_dropdown = gr.Dropdown(
        choices=[
            "LLaMA 4 (SambaNova)",
            "MiniMax M1 (Novita)",
            "Mistral Mixtral-8x7B (Together)",
        ],
        value="LLaMA 4 (SambaNova)",
        label="Select Model",
    )

    prompt_input = gr.Textbox(label="Text Prompt", placeholder="Ask something...", lines=2)
    image_url_input = gr.Textbox(label="Optional Image URL (for LLaMA only)", placeholder="https://example.com/image.jpg")

    submit_btn = gr.Button("💬 Generate Response")
    reset_btn = gr.Button("🔄 Reset Conversation")
    output_box = gr.Markdown(label="Chat History", value="")
    state = gr.State([])

    submit_btn.click(
        fn=chat_with_model,
        inputs=[model_dropdown, prompt_input, image_url_input, state],
        outputs=[output_box, state, prompt_input, image_url_input],
    )

    reset_btn.click(
        fn=lambda: ("🧹 Conversation reset. You can start a new one.", [], "", ""),
        inputs=[],
        outputs=[output_box, state, prompt_input, image_url_input],
    )

demo.launch()