Update app.py
Browse files
app.py
CHANGED
@@ -7,6 +7,10 @@ import logging
|
|
7 |
import io
|
8 |
import time
|
9 |
from typing import List, Dict, Any, Union, Tuple, Optional
|
|
|
|
|
|
|
|
|
10 |
|
11 |
# Configure logging
|
12 |
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
|
@@ -220,7 +224,7 @@ OPENAI_MODELS = {
|
|
220 |
"o1-mini-2024-09-12": 128000,
|
221 |
}
|
222 |
|
223 |
-
# HUGGINGFACE MODELS
|
224 |
HUGGINGFACE_MODELS = {
|
225 |
"microsoft/phi-3-mini-4k-instruct": 4096,
|
226 |
"microsoft/Phi-3-mini-128k-instruct": 131072,
|
@@ -509,9 +513,7 @@ def filter_models(provider, search_term):
|
|
509 |
if filtered_models:
|
510 |
return filtered_models, filtered_models[0]
|
511 |
else:
|
512 |
-
return
|
513 |
-
|
514 |
-
return all_models, all_models[0] if all_models else None
|
515 |
|
516 |
def get_model_info(provider, model_choice):
|
517 |
"""Get model ID and context size based on provider and model name"""
|
@@ -1688,14 +1690,14 @@ def create_app():
|
|
1688 |
# Define event handlers
|
1689 |
def toggle_model_dropdowns(provider):
|
1690 |
"""Show/hide model dropdowns based on provider selection"""
|
1691 |
-
return
|
1692 |
-
|
1693 |
-
|
1694 |
-
|
1695 |
-
|
1696 |
-
|
1697 |
-
|
1698 |
-
|
1699 |
|
1700 |
def update_context_for_provider(provider, openrouter_model, openai_model, hf_model, groq_model, cohere_model, glhf_model):
|
1701 |
"""Update context display based on selected provider and model"""
|
@@ -1728,33 +1730,68 @@ def create_app():
|
|
1728 |
elif provider == "GLHF":
|
1729 |
return update_model_info(provider, glhf_model)
|
1730 |
return "<p>Model information not available</p>"
|
|
|
|
|
|
|
|
|
|
|
1731 |
|
1732 |
-
def filter_provider_models(provider, search_term):
|
1733 |
-
"""Filter models for the selected provider"""
|
1734 |
if provider == "OpenRouter":
|
1735 |
all_models = [model[0] for model in OPENROUTER_ALL_MODELS]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1736 |
elif provider == "OpenAI":
|
1737 |
all_models = list(OPENAI_MODELS.keys())
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1738 |
elif provider == "HuggingFace":
|
1739 |
all_models = list(HUGGINGFACE_MODELS.keys())
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1740 |
elif provider == "Groq":
|
1741 |
all_models = list(GROQ_MODELS.keys())
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1742 |
elif provider == "Cohere":
|
1743 |
all_models = list(COHERE_MODELS.keys())
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1744 |
elif provider == "GLHF":
|
1745 |
all_models = list(GLHF_MODELS.keys())
|
1746 |
-
|
1747 |
-
|
|
|
|
|
1748 |
|
1749 |
-
|
1750 |
-
return all_models, all_models[0] if all_models else None
|
1751 |
|
1752 |
-
|
1753 |
-
|
1754 |
-
if filtered_models:
|
1755 |
-
return filtered_models, filtered_models[0]
|
1756 |
-
else:
|
1757 |
-
return all_models, all_models[0] if all_models else None
|
1758 |
|
1759 |
def refresh_groq_models_list():
|
1760 |
"""Refresh the list of Groq models"""
|
@@ -1800,14 +1837,12 @@ def create_app():
|
|
1800 |
outputs=model_info_display
|
1801 |
)
|
1802 |
|
1803 |
-
# Set up model search event
|
|
|
1804 |
model_search.change(
|
1805 |
-
fn=
|
1806 |
inputs=[provider_choice, model_search],
|
1807 |
-
outputs=[
|
1808 |
-
gr.update(choices=None, value=None),
|
1809 |
-
gr.update(choices=None, value=None)
|
1810 |
-
]
|
1811 |
)
|
1812 |
|
1813 |
# Set up model change events
|
@@ -1871,6 +1906,25 @@ def create_app():
|
|
1871 |
outputs=model_info_display
|
1872 |
)
|
1873 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1874 |
# Set up submission event
|
1875 |
def submit_message(message, history, provider, openrouter_model, openai_model, hf_model, groq_model, cohere_model, glhf_model,
|
1876 |
temperature, max_tokens, top_p, frequency_penalty, presence_penalty, repetition_penalty,
|
@@ -1963,11 +2017,40 @@ def create_app():
|
|
1963 |
|
1964 |
# Launch the app
|
1965 |
if __name__ == "__main__":
|
1966 |
-
# Check API keys
|
|
|
|
|
1967 |
if not OPENROUTER_API_KEY:
|
1968 |
logger.warning("WARNING: OPENROUTER_API_KEY environment variable is not set")
|
1969 |
-
|
1970 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1971 |
demo = create_app()
|
1972 |
demo.launch(
|
1973 |
server_name="0.0.0.0",
|
|
|
7 |
import io
|
8 |
import time
|
9 |
from typing import List, Dict, Any, Union, Tuple, Optional
|
10 |
+
from dotenv import load_dotenv
|
11 |
+
|
12 |
+
# Load environment variables from .env file
|
13 |
+
load_dotenv()
|
14 |
|
15 |
# Configure logging
|
16 |
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
|
|
|
224 |
"o1-mini-2024-09-12": 128000,
|
225 |
}
|
226 |
|
227 |
+
# HUGGINGFACE MODELS
|
228 |
HUGGINGFACE_MODELS = {
|
229 |
"microsoft/phi-3-mini-4k-instruct": 4096,
|
230 |
"microsoft/Phi-3-mini-128k-instruct": 131072,
|
|
|
513 |
if filtered_models:
|
514 |
return filtered_models, filtered_models[0]
|
515 |
else:
|
516 |
+
return all_models, all_models[0] if all_models else None
|
|
|
|
|
517 |
|
518 |
def get_model_info(provider, model_choice):
|
519 |
"""Get model ID and context size based on provider and model name"""
|
|
|
1690 |
# Define event handlers
|
1691 |
def toggle_model_dropdowns(provider):
|
1692 |
"""Show/hide model dropdowns based on provider selection"""
|
1693 |
+
return [
|
1694 |
+
gr.update(visible=(provider == "OpenRouter")),
|
1695 |
+
gr.update(visible=(provider == "OpenAI")),
|
1696 |
+
gr.update(visible=(provider == "HuggingFace")),
|
1697 |
+
gr.update(visible=(provider == "Groq")),
|
1698 |
+
gr.update(visible=(provider == "Cohere")),
|
1699 |
+
gr.update(visible=(provider == "GLHF"))
|
1700 |
+
]
|
1701 |
|
1702 |
def update_context_for_provider(provider, openrouter_model, openai_model, hf_model, groq_model, cohere_model, glhf_model):
|
1703 |
"""Update context display based on selected provider and model"""
|
|
|
1730 |
elif provider == "GLHF":
|
1731 |
return update_model_info(provider, glhf_model)
|
1732 |
return "<p>Model information not available</p>"
|
1733 |
+
|
1734 |
+
# Handling model search function - Fixed compared to previous implementation
|
1735 |
+
def search_models(provider, search_term):
|
1736 |
+
"""Filter models for the selected provider based on search term"""
|
1737 |
+
filtered_models = []
|
1738 |
|
|
|
|
|
1739 |
if provider == "OpenRouter":
|
1740 |
all_models = [model[0] for model in OPENROUTER_ALL_MODELS]
|
1741 |
+
if search_term:
|
1742 |
+
filtered_models = [model for model in all_models if search_term.lower() in model.lower()]
|
1743 |
+
else:
|
1744 |
+
filtered_models = all_models
|
1745 |
+
|
1746 |
+
return gr.update(choices=filtered_models, value=filtered_models[0] if filtered_models else None)
|
1747 |
+
|
1748 |
elif provider == "OpenAI":
|
1749 |
all_models = list(OPENAI_MODELS.keys())
|
1750 |
+
if search_term:
|
1751 |
+
filtered_models = [model for model in all_models if search_term.lower() in model.lower()]
|
1752 |
+
else:
|
1753 |
+
filtered_models = all_models
|
1754 |
+
|
1755 |
+
return gr.update(choices=filtered_models, value=filtered_models[0] if filtered_models else None)
|
1756 |
+
|
1757 |
elif provider == "HuggingFace":
|
1758 |
all_models = list(HUGGINGFACE_MODELS.keys())
|
1759 |
+
if search_term:
|
1760 |
+
filtered_models = [model for model in all_models if search_term.lower() in model.lower()]
|
1761 |
+
else:
|
1762 |
+
filtered_models = all_models
|
1763 |
+
|
1764 |
+
return gr.update(choices=filtered_models, value=filtered_models[0] if filtered_models else None)
|
1765 |
+
|
1766 |
elif provider == "Groq":
|
1767 |
all_models = list(GROQ_MODELS.keys())
|
1768 |
+
if search_term:
|
1769 |
+
filtered_models = [model for model in all_models if search_term.lower() in model.lower()]
|
1770 |
+
else:
|
1771 |
+
filtered_models = all_models
|
1772 |
+
|
1773 |
+
return gr.update(choices=filtered_models, value=filtered_models[0] if filtered_models else None)
|
1774 |
+
|
1775 |
elif provider == "Cohere":
|
1776 |
all_models = list(COHERE_MODELS.keys())
|
1777 |
+
if search_term:
|
1778 |
+
filtered_models = [model for model in all_models if search_term.lower() in model.lower()]
|
1779 |
+
else:
|
1780 |
+
filtered_models = all_models
|
1781 |
+
|
1782 |
+
return gr.update(choices=filtered_models, value=filtered_models[0] if filtered_models else None)
|
1783 |
+
|
1784 |
elif provider == "GLHF":
|
1785 |
all_models = list(GLHF_MODELS.keys())
|
1786 |
+
if search_term:
|
1787 |
+
filtered_models = [model for model in all_models if search_term.lower() in model.lower()]
|
1788 |
+
else:
|
1789 |
+
filtered_models = all_models
|
1790 |
|
1791 |
+
return gr.update(choices=filtered_models, value=filtered_models[0] if filtered_models else None)
|
|
|
1792 |
|
1793 |
+
# Default return in case of unknown provider
|
1794 |
+
return gr.update(choices=[], value=None)
|
|
|
|
|
|
|
|
|
1795 |
|
1796 |
def refresh_groq_models_list():
|
1797 |
"""Refresh the list of Groq models"""
|
|
|
1837 |
outputs=model_info_display
|
1838 |
)
|
1839 |
|
1840 |
+
# Set up model search event - FIXED VERSION
|
1841 |
+
# Important: We need to return a proper Gradio component update for each dropdown
|
1842 |
model_search.change(
|
1843 |
+
fn=search_models,
|
1844 |
inputs=[provider_choice, model_search],
|
1845 |
+
outputs=[openrouter_model] # This will be handled by the JS forwarding logic
|
|
|
|
|
|
|
1846 |
)
|
1847 |
|
1848 |
# Set up model change events
|
|
|
1906 |
outputs=model_info_display
|
1907 |
)
|
1908 |
|
1909 |
+
# Add custom JavaScript for routing model search to visible dropdown
|
1910 |
+
gr.HTML("""
|
1911 |
+
<script>
|
1912 |
+
// To be triggered after page load
|
1913 |
+
document.addEventListener('DOMContentLoaded', function() {
|
1914 |
+
// Find dropdowns
|
1915 |
+
const providerRadio = document.querySelector('input[name="provider_choice"]');
|
1916 |
+
const searchInput = document.getElementById('model_search');
|
1917 |
+
|
1918 |
+
if (providerRadio && searchInput) {
|
1919 |
+
// When provider changes, clear the search
|
1920 |
+
providerRadio.addEventListener('change', function() {
|
1921 |
+
searchInput.value = '';
|
1922 |
+
});
|
1923 |
+
}
|
1924 |
+
});
|
1925 |
+
</script>
|
1926 |
+
""")
|
1927 |
+
|
1928 |
# Set up submission event
|
1929 |
def submit_message(message, history, provider, openrouter_model, openai_model, hf_model, groq_model, cohere_model, glhf_model,
|
1930 |
temperature, max_tokens, top_p, frequency_penalty, presence_penalty, repetition_penalty,
|
|
|
2017 |
|
2018 |
# Launch the app
|
2019 |
if __name__ == "__main__":
|
2020 |
+
# Check API keys and print status
|
2021 |
+
missing_keys = []
|
2022 |
+
|
2023 |
if not OPENROUTER_API_KEY:
|
2024 |
logger.warning("WARNING: OPENROUTER_API_KEY environment variable is not set")
|
2025 |
+
missing_keys.append("OpenRouter")
|
2026 |
|
2027 |
+
if not OPENAI_API_KEY:
|
2028 |
+
logger.warning("WARNING: OPENAI_API_KEY environment variable is not set")
|
2029 |
+
missing_keys.append("OpenAI")
|
2030 |
+
|
2031 |
+
if not GROQ_API_KEY:
|
2032 |
+
logger.warning("WARNING: GROQ_API_KEY environment variable is not set")
|
2033 |
+
missing_keys.append("Groq")
|
2034 |
+
|
2035 |
+
if not COHERE_API_KEY:
|
2036 |
+
logger.warning("WARNING: COHERE_API_KEY environment variable is not set")
|
2037 |
+
missing_keys.append("Cohere")
|
2038 |
+
|
2039 |
+
if not GLHF_API_KEY:
|
2040 |
+
logger.warning("WARNING: GLHF_API_KEY environment variable is not set")
|
2041 |
+
missing_keys.append("GLHF")
|
2042 |
+
|
2043 |
+
if missing_keys:
|
2044 |
+
print("Missing API keys for the following providers:")
|
2045 |
+
for key in missing_keys:
|
2046 |
+
print(f"- {key}")
|
2047 |
+
print("\nYou can still use the application, but some providers will require API keys.")
|
2048 |
+
print("You can provide API keys through environment variables or use the API Key Override field.")
|
2049 |
+
|
2050 |
+
if "OpenRouter" in missing_keys:
|
2051 |
+
print("\nNote: OpenRouter offers free tier access to many models!")
|
2052 |
+
|
2053 |
+
print("\nStarting Multi-Provider CrispChat application...")
|
2054 |
demo = create_app()
|
2055 |
demo.launch(
|
2056 |
server_name="0.0.0.0",
|