# Ollama_test / app.py
# Uploaded by xtreme86 via huggingface_hub (commit 5e1f2b6)
import gradio as gr
import httpx
import json
import asyncio
import os
from chat_state import chat_state
from config import OLLAMA_URL, DEFAULT_TEMPERATURE, DEFAULT_SYSTEM_MESSAGE
# Shared Gradio theme for the whole Blocks UI: Soft base with a yellow
# primary hue and Montserrat as the preferred font (falls back to the
# system sans-serif stack when the Google font cannot load).
theme = gr.themes.Soft(
primary_hue="yellow",
neutral_hue="neutral",
text_size="md",
spacing_size="md",
radius_size="md",
font=[gr.themes.GoogleFont('Montserrat'), gr.themes.GoogleFont('ui-sans-serif'), 'system-ui', 'sans-serif'],
)
async def fetch_available_models():
    """Return the names of models served by the local Ollama instance.

    Queries ``GET /api/tags``. Returns an empty list on any HTTP or
    network failure so the UI can still render with an empty dropdown.
    """
    async with httpx.AsyncClient() as client:
        try:
            response = await client.get(f"{OLLAMA_URL}/api/tags")
            response.raise_for_status()
            data = response.json()
            return [model["name"] for model in data.get("models", [])]
        # httpx.HTTPError covers both bad status codes AND transport
        # failures (connection refused, DNS, timeouts). The original
        # caught only HTTPStatusError, so an unreachable server crashed
        # the app at startup instead of degrading gracefully.
        except httpx.HTTPError as e:
            print(f"Error fetching models: {e}")
            return []
async def get_model_info(model_name):
    """Fetch metadata for *model_name* via Ollama's ``POST /api/show``.

    Returns the parsed JSON dict, or ``{}`` on any HTTP or network
    failure so callers can safely chain ``.get(...)`` on the result.
    """
    async with httpx.AsyncClient() as client:
        try:
            response = await client.post(f"{OLLAMA_URL}/api/show", json={"name": model_name})
            response.raise_for_status()
            return response.json()
        # Broadened from HTTPStatusError: transport errors (server down,
        # timeout) previously propagated and broke the model dropdown.
        except httpx.HTTPError as e:
            print(f"Error fetching model info: {e}")
            return {}
async def call_ollama_api(prompt, history):
    """Stream a chat completion from Ollama, yielding the growing reply.

    Args:
        prompt: The new user message.
        history: Prior turns as ``[user_msg, assistant_msg]`` pairs
            (Gradio chatbot format); either side may be falsy.

    Yields:
        The accumulated assistant text after each streamed chunk, so the
        caller can progressively update the chat display.
    """
    # Rebuild the full conversation: system message, prior turns, new prompt.
    messages = [{"role": "system", "content": chat_state.system_message}]
    for user_msg, assistant_msg in history:
        if user_msg:
            messages.append({"role": "user", "content": user_msg})
        if assistant_msg:
            messages.append({"role": "assistant", "content": assistant_msg})
    messages.append({"role": "user", "content": prompt})
    payload = {
        "model": chat_state.model,
        "messages": messages,
        "stream": True,
        # Ollama's /api/chat reads sampling parameters from the "options"
        # object; a top-level "temperature" key is silently ignored, so the
        # original temperature slider had no effect.
        "options": {"temperature": chat_state.temperature},
    }
    async with httpx.AsyncClient() as client:
        try:
            async with client.stream("POST", f"{OLLAMA_URL}/api/chat", json=payload, timeout=30.0) as response:
                response.raise_for_status()
                full_response = ""
                # Each streamed line is one JSON object with a partial message.
                async for line in response.aiter_lines():
                    if not line:
                        continue
                    json_line = json.loads(line)
                    message_content = json_line.get('message', {}).get('content', '')
                    if message_content:
                        full_response += message_content
                        yield full_response
                    if json_line.get('done'):
                        break
        # httpx raises its own TimeoutException, not asyncio.TimeoutError,
        # so the original timeout handler could never fire. Catch it first:
        # it is a subclass of httpx.HTTPError.
        except httpx.TimeoutException:
            yield "The request timed out. Please try again."
        except httpx.HTTPError as e:
            yield f"An error occurred: {e}"
async def user(user_message, history):
    """Append the new user turn (assistant side pending) and clear the textbox."""
    updated_history = history + [[user_message, None]]
    return "", updated_history
async def bot(history):
    """Stream the assistant's reply into the last history row, yielding updates."""
    prompt = history[-1][0]
    # Everything before the last row is the prior conversation context.
    async for partial_reply in call_ollama_api(prompt, history[:-1]):
        history[-1][1] = partial_reply
        yield history
def clear_chat():
    """Return None, which resets the gr.Chatbot component it is wired to."""
    return None
def save_chat_history(history, filename="chat_history.json"):
    """Serialize the chat history to a JSON file.

    Args:
        history: Gradio chatbot history (list of [user, assistant] pairs).
        filename: Destination path for the JSON dump.

    Returns:
        A status string for display in the UI's status box.
    """
    with open(filename, "w") as f:
        json.dump(history, f)
    # Fixed: the original f-string had no placeholder and always reported
    # "Chat history saved to (unknown)" regardless of the actual filename.
    return f"Chat history saved to {filename}"
def load_chat_history(filename="chat_history.json"):
    """Read a previously saved chat history; return None if the file is absent."""
    try:
        handle = open(filename, "r")
    except FileNotFoundError:
        return None
    with handle:
        return json.load(handle)
async def change_model(model_name):
    """Switch the active model and build a (status, model-info) string pair."""
    chat_state.model = model_name
    info = await get_model_info(model_name)
    # All displayed attributes live under the "details" key of /api/show output.
    details = info.get('details', {})
    info_lines = [
        f"Model: {model_name}",
        f"Parameter Size: {details.get('parameter_size', 'Unknown')}",
        f"Quantization: {details.get('quantization_level', 'Unknown')}",
        f"Format: {details.get('format', 'Unknown')}",
    ]
    return f"Model changed to {chat_state.model}", "\n".join(info_lines)
def update_temperature(new_temp):
    """Store the sampling temperature on shared state; return a status string."""
    value = float(new_temp)
    chat_state.temperature = value
    return f"Temperature set to {value}"
def update_system_message(new_message):
    """Replace the system prompt on shared state; return a status string."""
    chat_state.system_message = new_message
    return f"System message updated: {new_message}"
async def initialize_interface():
    """Build and return the Gradio Blocks UI.

    Async because it first fetches the available models from Ollama to
    populate the model dropdown. Also selects the first available model
    as the initial active model on shared chat_state.
    """
    chat_state.available_models = await fetch_available_models()
    with gr.Blocks(theme=theme) as demo:
        gr.Markdown("# 🤖 Enhanced Ollama Chatbot Interface")
        with gr.Row():
            # Left column: the chat display and message input.
            with gr.Column(scale=7):
                chatbot = gr.Chatbot(height=600, elem_id="chatbot")
                with gr.Row():
                    msg = gr.Textbox(
                        label="Message",
                        placeholder="Type your message here...",
                        scale=4,
                        elem_id="user-input"
                    )
                    send = gr.Button("Send", scale=1, elem_id="send-btn")
            # Right column: model/temperature settings, system prompt, chat management.
            with gr.Column(scale=3):
                with gr.Accordion("Model Settings", open=True):
                    model_dropdown = gr.Dropdown(
                        choices=chat_state.available_models,
                        label="Select Model",
                        # Default to the first model when any were discovered.
                        value=chat_state.available_models[0] if chat_state.available_models else None,
                        elem_id="model-dropdown"
                    )
                    model_info = gr.Textbox(label="Model Information", interactive=False, lines=4)
                    temp_slider = gr.Slider(
                        minimum=0, maximum=1, value=DEFAULT_TEMPERATURE, step=0.1,
                        label="Temperature",
                        elem_id="temp-slider"
                    )
                with gr.Accordion("System Message", open=False):
                    system_message_input = gr.Textbox(
                        label="System Message",
                        value=DEFAULT_SYSTEM_MESSAGE,
                        lines=3,
                        elem_id="system-message"
                    )
                    update_system_button = gr.Button("Update System Message", elem_id="update-system-btn")
                with gr.Accordion("Chat Management", open=False):
                    with gr.Row():
                        clear = gr.Button("Clear Chat", elem_id="clear-btn")
                        save_button = gr.Button("Save Chat", elem_id="save-btn")
                        load_button = gr.Button("Load Chat", elem_id="load-btn")
                status_box = gr.Textbox(label="Status", interactive=False, elem_id="status-box")
        # Event handlers
        # Submitting the textbox (or clicking Send) first appends the user turn
        # via `user` (queue=False so the echo is immediate), then streams the
        # assistant reply via `bot`, which progressively updates the chatbot.
        send_event = msg.submit(user, [msg, chatbot], [msg, chatbot], queue=False).then(
            bot, chatbot, chatbot
        )
        send.click(user, [msg, chatbot], [msg, chatbot], queue=False).then(
            bot, chatbot, chatbot
        )
        clear.click(clear_chat, outputs=[chatbot])
        model_dropdown.change(change_model, inputs=[model_dropdown], outputs=[status_box, model_info])
        temp_slider.change(update_temperature, inputs=[temp_slider], outputs=[status_box])
        update_system_button.click(update_system_message, inputs=[system_message_input], outputs=[status_box])
        save_button.click(save_chat_history, inputs=[chatbot], outputs=[status_box])
        load_button.click(load_chat_history, outputs=[chatbot])
    # Initialize the first model
    if chat_state.available_models:
        chat_state.model = chat_state.available_models[0]
    return demo
if __name__ == "__main__":
    # Build the UI (async: it fetches the model list first), then serve it
    # with a public share link.
    app = asyncio.run(initialize_interface())
    app.launch(share=True)