import os
import io
import json
import base64
import logging

import requests
import gradio as gr
import PyPDF2
from PIL import Image

logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)
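
# The API key is read once at startup from the environment; requests below
# will come back with an auth error if OPENROUTER_API_KEY is not set.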
OPENROUTER_API_KEY = os.environ.get("OPENROUTER_API_KEY", "")
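
# Model catalog, grouped for the category browser in the UI. Each entry is
# (display name, OpenRouter model id, context window size in tokens).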
MODELS = [
    {"category": "Vision", "models": [
        ("Meta: Llama 3.2 11B Vision Instruct", "meta-llama/llama-3.2-11b-vision-instruct:free", 131072),
        ("Qwen2.5 VL 72B Instruct", "qwen/qwen2.5-vl-72b-instruct:free", 131072),
        ("Qwen2.5 VL 32B Instruct", "qwen/qwen2.5-vl-32b-instruct:free", 8192),
        ("Qwen2.5 VL 7B Instruct", "qwen/qwen-2.5-vl-7b-instruct:free", 64000),
        ("Qwen2.5 VL 3B Instruct", "qwen/qwen2.5-vl-3b-instruct:free", 64000),
    ]},
    {"category": "Gemini", "models": [
        ("Gemini Pro 2.0 Experimental", "google/gemini-2.0-pro-exp-02-05:free", 2000000),
        ("Gemini Pro 2.5 Experimental", "google/gemini-2.5-pro-exp-03-25:free", 1000000),
        ("Gemini 2.0 Flash Thinking Experimental", "google/gemini-2.0-flash-thinking-exp:free", 1048576),
        ("Gemini Flash 2.0 Experimental", "google/gemini-2.0-flash-exp:free", 1048576),
        ("Gemini Flash 1.5 8B Experimental", "google/gemini-flash-1.5-8b-exp", 1000000),
        ("LearnLM 1.5 Pro Experimental", "google/learnlm-1.5-pro-experimental:free", 40960),
    ]},
    {"category": "Llama", "models": [
        ("Llama 3.3 70B Instruct", "meta-llama/llama-3.3-70b-instruct:free", 8000),
        ("Llama 3.2 3B Instruct", "meta-llama/llama-3.2-3b-instruct:free", 20000),
        ("Llama 3.2 1B Instruct", "meta-llama/llama-3.2-1b-instruct:free", 131072),
        ("Llama 3.1 8B Instruct", "meta-llama/llama-3.1-8b-instruct:free", 131072),
        ("Llama 3 8B Instruct", "meta-llama/llama-3-8b-instruct:free", 8192),
        ("Llama 3.1 Nemotron 70B Instruct", "nvidia/llama-3.1-nemotron-70b-instruct:free", 131072),
    ]},
    {"category": "DeepSeek", "models": [
        ("DeepSeek R1 Zero", "deepseek/deepseek-r1-zero:free", 163840),
        ("DeepSeek R1", "deepseek/deepseek-r1:free", 163840),
        ("DeepSeek V3 Base", "deepseek/deepseek-v3-base:free", 131072),
        ("DeepSeek V3 0324", "deepseek/deepseek-v3-0324:free", 131072),
        ("DeepSeek V3", "deepseek/deepseek-chat:free", 131072),
        ("DeepSeek R1 Distill Qwen 14B", "deepseek/deepseek-r1-distill-qwen-14b:free", 64000),
        ("DeepSeek R1 Distill Qwen 32B", "deepseek/deepseek-r1-distill-qwen-32b:free", 16000),
        ("DeepSeek R1 Distill Llama 70B", "deepseek/deepseek-r1-distill-llama-70b:free", 8192),
    ]},
    {"category": "Other Popular Models", "models": [
        ("Mistral Nemo", "mistralai/mistral-nemo:free", 128000),
        ("Mistral Small 3.1 24B", "mistralai/mistral-small-3.1-24b-instruct:free", 96000),
        ("Gemma 3 27B", "google/gemma-3-27b-it:free", 96000),
        ("Gemma 3 12B", "google/gemma-3-12b-it:free", 131072),
        ("Gemma 3 4B", "google/gemma-3-4b-it:free", 131072),
        ("DeepHermes 3 Llama 3 8B Preview", "nousresearch/deephermes-3-llama-3-8b-preview:free", 131072),
        ("Qwen2.5 72B Instruct", "qwen/qwen-2.5-72b-instruct:free", 32768),
    ]},
    {"category": "Smaller Models", "models": [
        ("Gemma 3 1B", "google/gemma-3-1b-it:free", 32768),
        ("Gemma 2 9B", "google/gemma-2-9b-it:free", 8192),
        ("Mistral 7B Instruct", "mistralai/mistral-7b-instruct:free", 8192),
        ("Qwen 2 7B Instruct", "qwen/qwen-2-7b-instruct:free", 8192),
        ("Phi-3 Mini 128K Instruct", "microsoft/phi-3-mini-128k-instruct:free", 8192),
        ("Phi-3 Medium 128K Instruct", "microsoft/phi-3-medium-128k-instruct:free", 8192),
        ("OpenChat 3.5 7B", "openchat/openchat-7b:free", 8192),
        ("Zephyr 7B", "huggingfaceh4/zephyr-7b-beta:free", 4096),
        ("MythoMax 13B", "gryphe/mythomax-l2-13b:free", 4096),
    ]},
]
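
# Flattened view of the catalog: one list of (name, id, context) tuples used
# by the dropdown, search filter, and lookup helpers below.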
ALL_MODELS = [model for category in MODELS for model in category["models"]]


def format_to_message_dict(history):
    """Convert Gradio chat history pairs into OpenAI-style message dicts."""
    messages = []
    for pair in history:
        if len(pair) == 2:
            human, ai = pair
            if human:
                messages.append({"role": "user", "content": human})
            if ai:
                messages.append({"role": "assistant", "content": ai})
    return messages
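

# Uploaded images are inlined as base64 data URIs inside "image_url" content
# parts, the OpenAI-style format that OpenRouter's chat endpoint accepts.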
def encode_image_to_base64(image_path):
    """Encode an image (file path or PIL image) as a base64 data URI."""
    try:
        if isinstance(image_path, str):
            with open(image_path, "rb") as image_file:
                encoded_string = base64.b64encode(image_file.read()).decode('utf-8')
            file_extension = image_path.split('.')[-1].lower()
            mime_type = "image/jpeg" if file_extension in ("jpg", "jpeg") else f"image/{file_extension}"
            return f"data:{mime_type};base64,{encoded_string}"
        else:
            # Anything that is not a path is treated as a PIL image.
            buffered = io.BytesIO()
            image_path.save(buffered, format="PNG")
            encoded_string = base64.b64encode(buffered.getvalue()).decode('utf-8')
            return f"data:image/png;base64,{encoded_string}"
    except Exception as e:
        logger.error(f"Error encoding image: {str(e)}")
        return None


def extract_text_from_file(file_path):
    """Extract plain text from a PDF, Markdown, or text file."""
    try:
        file_extension = file_path.split('.')[-1].lower()

        if file_extension == 'pdf':
            text = ""
            with open(file_path, 'rb') as file:
                pdf_reader = PyPDF2.PdfReader(file)
                for page in pdf_reader.pages:
                    # extract_text() can return None for image-only pages.
                    text += (page.extract_text() or "") + "\n\n"
            return text

        elif file_extension in ('md', 'txt'):
            with open(file_path, 'r', encoding='utf-8') as file:
                return file.read()

        else:
            return f"Unsupported file type: {file_extension}"

    except Exception as e:
        logger.error(f"Error extracting text from file: {str(e)}")
        return f"Error processing file: {str(e)}"


def prepare_message_with_media(text, images=None, documents=None):
    """Build message content from text plus optional images and documents."""
    # A plain text message needs no special handling.
    if not images and not documents:
        return text

    # Append extracted document text to the prompt.
    if documents:
        document_texts = []
        for doc in documents:
            if doc is None:
                continue
            doc_text = extract_text_from_file(doc)
            if doc_text:
                document_texts.append(doc_text)

        if document_texts:
            prefix = f"{text}\n\nDocument content:" if text else "Please analyze these documents:"
            text = prefix + "\n\n" + "\n\n".join(document_texts)

    # Without images the content stays a plain string.
    if not images:
        return text

    # With images, the content becomes a list of typed parts.
    content = [{"type": "text", "text": text}]
    for img in images:
        if img is None:
            continue
        encoded_image = encode_image_to_base64(img)
        if encoded_image:
            content.append({
                "type": "image_url",
                "image_url": {"url": encoded_image}
            })

    return content
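

# Core request path: resolve the selected model, rebuild the conversation,
# attach any media, and POST it to OpenRouter's chat/completions endpoint.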
def ask_ai(message, chatbot, model_choice, temperature, max_tokens, top_p, frequency_penalty,
           presence_penalty, images, documents, reasoning_effort):
    """Send the conversation to the selected model and append its reply."""
    if not (message or "").strip() and not images and not documents:
        return chatbot, ""

    # Resolve the display name chosen in the UI to an OpenRouter model id.
    model_id = None
    for name, model_id_value, _ctx_size in ALL_MODELS:
        if name == model_choice:
            model_id = model_id_value
            break

    if model_id is None:
        logger.error(f"Model not found: {model_choice}")
        return chatbot + [[message, "Error: Model not found"]], ""

    # Rebuild prior turns, then add the new user message (with any media).
    messages = format_to_message_dict(chatbot)
    content = prepare_message_with_media(message, images, documents)
    messages.append({"role": "user", "content": content})

    try:
        logger.info(f"Sending request to model: {model_id}")

        payload = {
            "model": model_id,
            "messages": messages,
            "temperature": temperature,
            "max_tokens": max_tokens,
            "top_p": top_p,
            "frequency_penalty": frequency_penalty,
            "presence_penalty": presence_penalty
        }

        if reasoning_effort != "none":
            payload["reasoning"] = {"effort": reasoning_effort}

        logger.info(f"Request payload: {json.dumps(payload, default=str)}")

        response = requests.post(
            "https://openrouter.ai/api/v1/chat/completions",
            headers={
                "Content-Type": "application/json",
                "Authorization": f"Bearer {OPENROUTER_API_KEY}",
                "HTTP-Referer": "https://huggingface.co/spaces"
            },
            json=payload,
            timeout=120
        )

        logger.info(f"Response status: {response.status_code}")
        response_text = response.text
        logger.info(f"Response body: {response_text}")

        if response.status_code == 200:
            result = response.json()
            ai_response = result.get("choices", [{}])[0].get("message", {}).get("content", "")
            chatbot = chatbot + [[message, ai_response]]

            if "usage" in result:
                logger.info(f"Token usage: {result['usage']}")
        else:
            error_message = f"Error: Status code {response.status_code}\n\nResponse: {response_text}"
            chatbot = chatbot + [[message, error_message]]
    except Exception as e:
        logger.error(f"Exception during API call: {str(e)}")
        chatbot = chatbot + [[message, f"Error: {str(e)}"]]

    return chatbot, ""
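

# Reset handler: the return values must line up one-to-one with the
# components listed in clear_btn.click(outputs=...).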
def clear_chat():
    return [], "", [], [], 0.7, 1000, 0.8, 0.0, 0.0, "none"


def filter_models(search_term):
    """Filter the model dropdown with a case-insensitive substring match."""
    all_choices = [model[0] for model in ALL_MODELS]
    if not search_term:
        return gr.update(choices=all_choices, value=ALL_MODELS[0][0])

    filtered_models = [name for name in all_choices if search_term.lower() in name.lower()]

    if filtered_models:
        return gr.update(choices=filtered_models, value=filtered_models[0])
    else:
        # No match: fall back to the full list.
        return gr.update(choices=all_choices, value=ALL_MODELS[0][0])


def get_model_info(model_name):
    """Get model information by name"""
    for model in ALL_MODELS:
        if model[0] == model_name:
            return model
    return None


def update_context_display(model_name):
    """Update the context size display based on the selected model"""
    model_info = get_model_info(model_name)
    if model_info:
        _, _, context_size = model_info
        return f"{context_size:,} tokens"
    return "Unknown"
with gr.Blocks(css="""
    .context-size {
        font-size: 0.9em;
        color: #666;
        margin-left: 10px;
    }
    footer { display: none !important; }
    .model-selection-row {
        display: flex;
        align-items: center;
    }
    .parameter-grid {
        display: grid;
        grid-template-columns: 1fr 1fr;
        gap: 10px;
    }
""") as demo:
    gr.Markdown("""
    # Enhanced AI Chat

    Chat with various AI models from OpenRouter with support for images and documents.
    """)

    with gr.Row():
        with gr.Column(scale=2):
            chatbot = gr.Chatbot(
                height=500,
                show_copy_button=True,
                show_label=False,
                avatar_images=(None, "https://upload.wikimedia.org/wikipedia/commons/0/04/ChatGPT_logo.svg")
            )

            with gr.Row():
                message = gr.Textbox(
                    placeholder="Type your message here...",
                    label="Message",
                    lines=2
                )

            with gr.Row():
                with gr.Column(scale=3):
                    submit_btn = gr.Button("Send", variant="primary")

                with gr.Column(scale=1):
                    clear_btn = gr.Button("Clear Chat", variant="secondary")

            with gr.Row():
                with gr.Accordion("Upload Images (for vision models)", open=False):
                    images = gr.Gallery(
                        label="Uploaded Images",
                        show_label=True,
                        columns=4,
                        height="auto",
                        object_fit="contain"
                    )

                    image_upload_btn = gr.UploadButton(
                        label="Upload Images",
                        file_types=["image"],
                        file_count="multiple"
                    )

                with gr.Accordion("Upload Documents (PDF, MD, TXT)", open=False):
                    documents = gr.File(
                        label="Uploaded Documents",
                        file_types=[".pdf", ".md", ".txt"],
                        file_count="multiple"
                    )

        with gr.Column(scale=1):
            with gr.Group():
                gr.Markdown("### Model Selection")

                with gr.Row(elem_classes="model-selection-row"):
                    model_search = gr.Textbox(
                        placeholder="Search models...",
                        label="",
                        show_label=False
                    )

                with gr.Row(elem_classes="model-selection-row"):
                    model_choice = gr.Dropdown(
                        [model[0] for model in ALL_MODELS],
                        value=ALL_MODELS[0][0],
                        label="Model"
                    )
                    context_display = gr.Textbox(
                        value=update_context_display(ALL_MODELS[0][0]),
                        label="Context",
                        interactive=False,
                        elem_classes="context-size"
                    )

                with gr.Accordion("Browse by Category", open=False):
                    model_categories = gr.Radio(
                        [category["category"] for category in MODELS],
                        label="Categories",
                        value=MODELS[0]["category"]
                    )

                    category_models = gr.Radio(
                        [model[0] for model in MODELS[0]["models"]],
                        label="Models in Category"
                    )

                with gr.Accordion("Generation Parameters", open=False):
                    with gr.Group(elem_classes="parameter-grid"):
                        temperature = gr.Slider(
                            minimum=0.0,
                            maximum=2.0,
                            value=0.7,
                            step=0.1,
                            label="Temperature"
                        )

                        max_tokens = gr.Slider(
                            minimum=100,
                            maximum=4000,
                            value=1000,
                            step=100,
                            label="Max Tokens"
                        )

                        top_p = gr.Slider(
                            minimum=0.1,
                            maximum=1.0,
                            value=0.8,
                            step=0.1,
                            label="Top P"
                        )

                        frequency_penalty = gr.Slider(
                            minimum=-2.0,
                            maximum=2.0,
                            value=0.0,
                            step=0.1,
                            label="Frequency Penalty"
                        )

                        presence_penalty = gr.Slider(
                            minimum=-2.0,
                            maximum=2.0,
                            value=0.0,
                            step=0.1,
                            label="Presence Penalty"
                        )

                        reasoning_effort = gr.Radio(
                            ["none", "low", "medium", "high"],
                            value="none",
                            label="Reasoning Effort"
                        )
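
    # Event wiring: the search box filters the dropdown, selecting a model
    # refreshes the context display, and the category browser feeds back
    # into the main dropdown.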
    model_search.change(
        fn=filter_models,
        inputs=[model_search],
        outputs=[model_choice]
    )

    model_choice.change(
        fn=update_context_display,
        inputs=[model_choice],
        outputs=[context_display]
    )

    def update_category_models(category):
        """Populate the category radio with the models in the chosen group."""
        for cat in MODELS:
            if cat["category"] == category:
                return gr.update(choices=[model[0] for model in cat["models"]], value=cat["models"][0][0])
        return gr.update(choices=[], value=None)

    model_categories.change(
        fn=update_category_models,
        inputs=[model_categories],
        outputs=[category_models]
    )

    # Picking a model in the category browser selects it in the main dropdown.
    category_models.change(
        fn=lambda x: x,
        inputs=[category_models],
        outputs=[model_choice]
    )

    def process_uploaded_images(files):
        # UploadButton may hand back tempfile objects or plain paths depending
        # on the Gradio version; normalize to paths for the gallery.
        return [file.name if hasattr(file, "name") else file for file in files]

    image_upload_btn.upload(
        fn=process_uploaded_images,
        inputs=[image_upload_btn],
        outputs=[images]
    )

    submit_btn.click(
        fn=ask_ai,
        inputs=[
            message, chatbot, model_choice, temperature, max_tokens,
            top_p, frequency_penalty, presence_penalty, images,
            documents, reasoning_effort
        ],
        outputs=[chatbot, message]
    )

    message.submit(
        fn=ask_ai,
        inputs=[
            message, chatbot, model_choice, temperature, max_tokens,
            top_p, frequency_penalty, presence_penalty, images,
            documents, reasoning_effort
        ],
        outputs=[chatbot, message]
    )

    clear_btn.click(
        fn=clear_chat,
        inputs=[],
        outputs=[
            chatbot, message, images, documents, temperature,
            max_tokens, top_p, frequency_penalty, presence_penalty, reasoning_effort
        ]
    )


if __name__ == "__main__":
    demo.launch(server_name="0.0.0.0", server_port=7860)