# Gradio front-end for chatting with a local Ollama server
# (streaming responses, model selection, temperature and system-message controls).
import asyncio
import json
import os

import gradio as gr
import httpx

from chat_state import chat_state
from config import OLLAMA_URL, DEFAULT_TEMPERATURE, DEFAULT_SYSTEM_MESSAGE
# Shared UI theme: Soft base with a yellow accent, medium sizing throughout,
# and Montserrat as the preferred font (system sans-serif as fallback).
theme = gr.themes.Soft(
    primary_hue="yellow",
    neutral_hue="neutral",
    text_size="md",
    spacing_size="md",
    radius_size="md",
    font=[
        gr.themes.GoogleFont('Montserrat'),
        gr.themes.GoogleFont('ui-sans-serif'),
        'system-ui',
        'sans-serif',
    ],
)
async def fetch_available_models():
    """Return the names of the models installed on the Ollama server.

    Queries GET {OLLAMA_URL}/api/tags. Returns an empty list on any HTTP or
    network failure so the interface can still start without a reachable server.
    """
    async with httpx.AsyncClient() as client:
        try:
            response = await client.get(f"{OLLAMA_URL}/api/tags")
            response.raise_for_status()
            data = response.json()
            return [model["name"] for model in data.get("models", [])]
        # httpx.HTTPError also covers transport errors (connection refused,
        # timeouts); the original HTTPStatusError-only catch let those escape
        # and crash interface start-up.
        except httpx.HTTPError as e:
            print(f"Error fetching models: {e}")
            return []
async def get_model_info(model_name):
    """Return Ollama's metadata dict for *model_name* (POST /api/show).

    Returns an empty dict on any HTTP or network failure so callers can
    safely chain .get() lookups.
    """
    async with httpx.AsyncClient() as client:
        try:
            response = await client.post(f"{OLLAMA_URL}/api/show", json={"name": model_name})
            response.raise_for_status()
            return response.json()
        # Broadened from HTTPStatusError: connection/timeout errors previously
        # propagated out of this best-effort lookup.
        except httpx.HTTPError as e:
            print(f"Error fetching model info: {e}")
            return {}
async def call_ollama_api(prompt, history):
    """Stream a chat completion from the Ollama /api/chat endpoint.

    Args:
        prompt: The new user message.
        history: Prior turns as [user_msg, assistant_msg] pairs; a None/empty
            side is skipped (e.g. the assistant slot of an in-flight turn).

    Yields:
        The accumulated assistant response after each streamed chunk, or a
        single error string on failure.
    """
    messages = [{"role": "system", "content": chat_state.system_message}]
    for user_msg, assistant_msg in history:
        if user_msg:
            messages.append({"role": "user", "content": user_msg})
        if assistant_msg:
            messages.append({"role": "assistant", "content": assistant_msg})
    messages.append({"role": "user", "content": prompt})
    payload = {
        "model": chat_state.model,
        "messages": messages,
        "stream": True,
        # Ollama's /api/chat reads sampling parameters from "options";
        # a top-level "temperature" key is silently ignored.
        "options": {"temperature": chat_state.temperature},
    }
    async with httpx.AsyncClient() as client:
        try:
            async with client.stream("POST", f"{OLLAMA_URL}/api/chat", json=payload, timeout=30.0) as response:
                response.raise_for_status()
                full_response = ""
                async for line in response.aiter_lines():
                    if not line:
                        continue
                    try:
                        json_line = json.loads(line)
                    except json.JSONDecodeError:
                        continue  # skip malformed/partial stream lines
                    message_content = json_line.get('message', {}).get('content', '')
                    if message_content:
                        full_response += message_content
                        yield full_response
                    if json_line.get('done'):
                        break
        # httpx raises its own TimeoutException (not asyncio.TimeoutError),
        # so the original asyncio.TimeoutError handler could never fire.
        # Must be caught before HTTPError, its base class.
        except httpx.TimeoutException:
            yield "The request timed out. Please try again."
        except httpx.HTTPError as e:
            yield f"An error occurred: {e}"
async def user(user_message, history):
    """Record the submitted message as a new, unanswered turn.

    Returns ("", new_history): the empty string clears the input textbox,
    and the new history ends with [message, None] for the bot to fill in.
    """
    updated_history = history + [[user_message, None]]
    return "", updated_history
async def bot(history):
    """Stream the assistant's reply into the last chat turn.

    Takes the pending user message from the final turn, streams the model
    response for it, and re-yields the whole history after each chunk so
    the Chatbot component updates incrementally.
    """
    prompt = history[-1][0]
    async for partial_reply in call_ollama_api(prompt, history[:-1]):
        history[-1][1] = partial_reply
        yield history
def clear_chat():
    """Return None, which resets the Chatbot component's value."""
    return None
def save_chat_history(history, filename="chat_history.json"):
    """Persist the chat history as JSON and report where it was written.

    Args:
        history: Chat turns as JSON-serializable [user, assistant] pairs.
        filename: Destination path (default "chat_history.json").

    Returns:
        A status message naming the output file.
    """
    with open(filename, "w") as f:
        json.dump(history, f)
    # Bug fix: the message previously hard-coded "(unknown)" instead of
    # interpolating the actual filename.
    return f"Chat history saved to {filename}"
def load_chat_history(filename="chat_history.json"):
    """Return the chat history stored in *filename*, or None if no save exists."""
    try:
        with open(filename, "r") as fp:
            raw = fp.read()
    except FileNotFoundError:
        return None
    return json.loads(raw)
async def change_model(model_name):
    """Make *model_name* the active model and summarize its metadata.

    Returns a (status message, model info text) pair for the UI. Missing
    metadata fields are shown as 'Unknown'.
    """
    chat_state.model = model_name
    details = (await get_model_info(model_name)).get('details', {})
    info_lines = [
        f"Model: {model_name}",
        f"Parameter Size: {details.get('parameter_size', 'Unknown')}",
        f"Quantization: {details.get('quantization_level', 'Unknown')}",
        f"Format: {details.get('format', 'Unknown')}",
    ]
    return f"Model changed to {chat_state.model}", "\n".join(info_lines)
def update_temperature(new_temp):
    """Store the slider's value on shared state and confirm the change."""
    value = float(new_temp)
    chat_state.temperature = value
    return f"Temperature set to {value}"
def update_system_message(new_message):
    """Replace the system prompt on shared state and echo it back."""
    chat_state.system_message = new_message
    return f"System message updated: {new_message}"
async def initialize_interface():
    """Build and return the Gradio Blocks app.

    Fetches the installed models first (async), then lays out a two-column
    UI: chat on the left, model/system/chat-management controls on the
    right. Also selects the first available model as the active one.
    """
    chat_state.available_models = await fetch_available_models()
    with gr.Blocks(theme=theme) as demo:
        gr.Markdown("# 🤖 Enhanced Ollama Chatbot Interface")
        with gr.Row():
            # Left column: the conversation itself plus the input row.
            with gr.Column(scale=7):
                chatbot = gr.Chatbot(height=600, elem_id="chatbot")
                with gr.Row():
                    msg = gr.Textbox(
                        label="Message",
                        placeholder="Type your message here...",
                        scale=4,
                        elem_id="user-input"
                    )
                    send = gr.Button("Send", scale=1, elem_id="send-btn")
            # Right column: settings accordions and the status readout.
            with gr.Column(scale=3):
                with gr.Accordion("Model Settings", open=True):
                    model_dropdown = gr.Dropdown(
                        choices=chat_state.available_models,
                        label="Select Model",
                        value=chat_state.available_models[0] if chat_state.available_models else None,
                        elem_id="model-dropdown"
                    )
                    model_info = gr.Textbox(label="Model Information", interactive=False, lines=4)
                    temp_slider = gr.Slider(
                        minimum=0, maximum=1, value=DEFAULT_TEMPERATURE, step=0.1,
                        label="Temperature",
                        elem_id="temp-slider"
                    )
                with gr.Accordion("System Message", open=False):
                    system_message_input = gr.Textbox(
                        label="System Message",
                        value=DEFAULT_SYSTEM_MESSAGE,
                        lines=3,
                        elem_id="system-message"
                    )
                    update_system_button = gr.Button("Update System Message", elem_id="update-system-btn")
                with gr.Accordion("Chat Management", open=False):
                    with gr.Row():
                        clear = gr.Button("Clear Chat", elem_id="clear-btn")
                        save_button = gr.Button("Save Chat", elem_id="save-btn")
                        load_button = gr.Button("Load Chat", elem_id="load-btn")
                status_box = gr.Textbox(label="Status", interactive=False, elem_id="status-box")
        # Event handlers. Submitting via Enter or the Send button both append
        # the user turn (unqueued, so the textbox clears immediately) and then
        # stream the bot reply. The previously assigned-but-unused
        # `send_event` local has been removed.
        msg.submit(user, [msg, chatbot], [msg, chatbot], queue=False).then(
            bot, chatbot, chatbot
        )
        send.click(user, [msg, chatbot], [msg, chatbot], queue=False).then(
            bot, chatbot, chatbot
        )
        clear.click(clear_chat, outputs=[chatbot])
        model_dropdown.change(change_model, inputs=[model_dropdown], outputs=[status_box, model_info])
        temp_slider.change(update_temperature, inputs=[temp_slider], outputs=[status_box])
        update_system_button.click(update_system_message, inputs=[system_message_input], outputs=[status_box])
        save_button.click(save_chat_history, inputs=[chatbot], outputs=[status_box])
        load_button.click(load_chat_history, outputs=[chatbot])
        # Initialize the active model to match the dropdown's default value.
        if chat_state.available_models:
            chat_state.model = chat_state.available_models[0]
    return demo
if __name__ == "__main__":
    # Build the app (async because model discovery queries the Ollama API),
    # then serve it with a public share link.
    app = asyncio.run(initialize_interface())
    app.launch(share=True)