cstr commited on
Commit
cf6c1c3
·
verified ·
1 Parent(s): 12009d0

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +115 -32
app.py CHANGED
@@ -7,6 +7,10 @@ import logging
7
  import io
8
  import time
9
  from typing import List, Dict, Any, Union, Tuple, Optional
 
 
 
 
10
 
11
  # Configure logging
12
  logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
@@ -220,7 +224,7 @@ OPENAI_MODELS = {
220
  "o1-mini-2024-09-12": 128000,
221
  }
222
 
223
- # HUGGINGFACE MODELS
224
  HUGGINGFACE_MODELS = {
225
  "microsoft/phi-3-mini-4k-instruct": 4096,
226
  "microsoft/Phi-3-mini-128k-instruct": 131072,
@@ -509,9 +513,7 @@ def filter_models(provider, search_term):
509
  if filtered_models:
510
  return filtered_models, filtered_models[0]
511
  else:
512
- return
513
-
514
- return all_models, all_models[0] if all_models else None
515
 
516
  def get_model_info(provider, model_choice):
517
  """Get model ID and context size based on provider and model name"""
@@ -1688,14 +1690,14 @@ def create_app():
1688
  # Define event handlers
1689
  def toggle_model_dropdowns(provider):
1690
  """Show/hide model dropdowns based on provider selection"""
1691
- return {
1692
- openrouter_model: gr.update(visible=(provider == "OpenRouter")),
1693
- openai_model: gr.update(visible=(provider == "OpenAI")),
1694
- hf_model: gr.update(visible=(provider == "HuggingFace")),
1695
- groq_model: gr.update(visible=(provider == "Groq")),
1696
- cohere_model: gr.update(visible=(provider == "Cohere")),
1697
- glhf_model: gr.update(visible=(provider == "GLHF"))
1698
- }
1699
 
1700
  def update_context_for_provider(provider, openrouter_model, openai_model, hf_model, groq_model, cohere_model, glhf_model):
1701
  """Update context display based on selected provider and model"""
@@ -1728,33 +1730,68 @@ def create_app():
1728
  elif provider == "GLHF":
1729
  return update_model_info(provider, glhf_model)
1730
  return "<p>Model information not available</p>"
 
 
 
 
 
1731
 
1732
- def filter_provider_models(provider, search_term):
1733
- """Filter models for the selected provider"""
1734
  if provider == "OpenRouter":
1735
  all_models = [model[0] for model in OPENROUTER_ALL_MODELS]
 
 
 
 
 
 
 
1736
  elif provider == "OpenAI":
1737
  all_models = list(OPENAI_MODELS.keys())
 
 
 
 
 
 
 
1738
  elif provider == "HuggingFace":
1739
  all_models = list(HUGGINGFACE_MODELS.keys())
 
 
 
 
 
 
 
1740
  elif provider == "Groq":
1741
  all_models = list(GROQ_MODELS.keys())
 
 
 
 
 
 
 
1742
  elif provider == "Cohere":
1743
  all_models = list(COHERE_MODELS.keys())
 
 
 
 
 
 
 
1744
  elif provider == "GLHF":
1745
  all_models = list(GLHF_MODELS.keys())
1746
- else:
1747
- return [], None
 
 
1748
 
1749
- if not search_term:
1750
- return all_models, all_models[0] if all_models else None
1751
 
1752
- filtered_models = [model for model in all_models if search_term.lower() in model.lower()]
1753
-
1754
- if filtered_models:
1755
- return filtered_models, filtered_models[0]
1756
- else:
1757
- return all_models, all_models[0] if all_models else None
1758
 
1759
  def refresh_groq_models_list():
1760
  """Refresh the list of Groq models"""
@@ -1800,14 +1837,12 @@ def create_app():
1800
  outputs=model_info_display
1801
  )
1802
 
1803
- # Set up model search event
 
1804
  model_search.change(
1805
- fn=lambda provider, search: filter_provider_models(provider, search),
1806
  inputs=[provider_choice, model_search],
1807
- outputs=[
1808
- gr.update(choices=None, value=None),
1809
- gr.update(choices=None, value=None)
1810
- ]
1811
  )
1812
 
1813
  # Set up model change events
@@ -1871,6 +1906,25 @@ def create_app():
1871
  outputs=model_info_display
1872
  )
1873
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1874
  # Set up submission event
1875
  def submit_message(message, history, provider, openrouter_model, openai_model, hf_model, groq_model, cohere_model, glhf_model,
1876
  temperature, max_tokens, top_p, frequency_penalty, presence_penalty, repetition_penalty,
@@ -1963,11 +2017,40 @@ def create_app():
1963
 
1964
  # Launch the app
1965
  if __name__ == "__main__":
1966
- # Check API keys before starting
 
 
1967
  if not OPENROUTER_API_KEY:
1968
  logger.warning("WARNING: OPENROUTER_API_KEY environment variable is not set")
1969
- print("WARNING: OpenRouter API key not found. Set OPENROUTER_API_KEY environment variable to access free models.")
1970
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1971
  demo = create_app()
1972
  demo.launch(
1973
  server_name="0.0.0.0",
 
7
  import io
8
  import time
9
  from typing import List, Dict, Any, Union, Tuple, Optional
10
+ from dotenv import load_dotenv
11
+
12
+ # Load environment variables from .env file
13
+ load_dotenv()
14
 
15
  # Configure logging
16
  logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
 
224
  "o1-mini-2024-09-12": 128000,
225
  }
226
 
227
+ # HUGGINGFACE MODELS
228
  HUGGINGFACE_MODELS = {
229
  "microsoft/phi-3-mini-4k-instruct": 4096,
230
  "microsoft/Phi-3-mini-128k-instruct": 131072,
 
513
  if filtered_models:
514
  return filtered_models, filtered_models[0]
515
  else:
516
+ return all_models, all_models[0] if all_models else None
 
 
517
 
518
  def get_model_info(provider, model_choice):
519
  """Get model ID and context size based on provider and model name"""
 
1690
  # Define event handlers
1691
  def toggle_model_dropdowns(provider):
1692
  """Show/hide model dropdowns based on provider selection"""
1693
+ return [
1694
+ gr.update(visible=(provider == "OpenRouter")),
1695
+ gr.update(visible=(provider == "OpenAI")),
1696
+ gr.update(visible=(provider == "HuggingFace")),
1697
+ gr.update(visible=(provider == "Groq")),
1698
+ gr.update(visible=(provider == "Cohere")),
1699
+ gr.update(visible=(provider == "GLHF"))
1700
+ ]
1701
 
1702
  def update_context_for_provider(provider, openrouter_model, openai_model, hf_model, groq_model, cohere_model, glhf_model):
1703
  """Update context display based on selected provider and model"""
 
1730
  elif provider == "GLHF":
1731
  return update_model_info(provider, glhf_model)
1732
  return "<p>Model information not available</p>"
1733
+
1734
+ # Handling model search function - Fixed compared to previous implementation
1735
+ def search_models(provider, search_term):
1736
+ """Filter models for the selected provider based on search term"""
1737
+ filtered_models = []
1738
 
 
 
1739
  if provider == "OpenRouter":
1740
  all_models = [model[0] for model in OPENROUTER_ALL_MODELS]
1741
+ if search_term:
1742
+ filtered_models = [model for model in all_models if search_term.lower() in model.lower()]
1743
+ else:
1744
+ filtered_models = all_models
1745
+
1746
+ return gr.update(choices=filtered_models, value=filtered_models[0] if filtered_models else None)
1747
+
1748
  elif provider == "OpenAI":
1749
  all_models = list(OPENAI_MODELS.keys())
1750
+ if search_term:
1751
+ filtered_models = [model for model in all_models if search_term.lower() in model.lower()]
1752
+ else:
1753
+ filtered_models = all_models
1754
+
1755
+ return gr.update(choices=filtered_models, value=filtered_models[0] if filtered_models else None)
1756
+
1757
  elif provider == "HuggingFace":
1758
  all_models = list(HUGGINGFACE_MODELS.keys())
1759
+ if search_term:
1760
+ filtered_models = [model for model in all_models if search_term.lower() in model.lower()]
1761
+ else:
1762
+ filtered_models = all_models
1763
+
1764
+ return gr.update(choices=filtered_models, value=filtered_models[0] if filtered_models else None)
1765
+
1766
  elif provider == "Groq":
1767
  all_models = list(GROQ_MODELS.keys())
1768
+ if search_term:
1769
+ filtered_models = [model for model in all_models if search_term.lower() in model.lower()]
1770
+ else:
1771
+ filtered_models = all_models
1772
+
1773
+ return gr.update(choices=filtered_models, value=filtered_models[0] if filtered_models else None)
1774
+
1775
  elif provider == "Cohere":
1776
  all_models = list(COHERE_MODELS.keys())
1777
+ if search_term:
1778
+ filtered_models = [model for model in all_models if search_term.lower() in model.lower()]
1779
+ else:
1780
+ filtered_models = all_models
1781
+
1782
+ return gr.update(choices=filtered_models, value=filtered_models[0] if filtered_models else None)
1783
+
1784
  elif provider == "GLHF":
1785
  all_models = list(GLHF_MODELS.keys())
1786
+ if search_term:
1787
+ filtered_models = [model for model in all_models if search_term.lower() in model.lower()]
1788
+ else:
1789
+ filtered_models = all_models
1790
 
1791
+ return gr.update(choices=filtered_models, value=filtered_models[0] if filtered_models else None)
 
1792
 
1793
+ # Default return in case of unknown provider
1794
+ return gr.update(choices=[], value=None)
 
 
 
 
1795
 
1796
  def refresh_groq_models_list():
1797
  """Refresh the list of Groq models"""
 
1837
  outputs=model_info_display
1838
  )
1839
 
1840
+ # Set up model search event - FIXED VERSION
1841
+ # Important: We need to return a proper Gradio component update for each dropdown
1842
  model_search.change(
1843
+ fn=search_models,
1844
  inputs=[provider_choice, model_search],
1845
+ outputs=[openrouter_model] # This will be handled by the JS forwarding logic
 
 
 
1846
  )
1847
 
1848
  # Set up model change events
 
1906
  outputs=model_info_display
1907
  )
1908
 
1909
+ # Add custom JavaScript for routing model search to visible dropdown
1910
+ gr.HTML("""
1911
+ <script>
1912
+ // To be triggered after page load
1913
+ document.addEventListener('DOMContentLoaded', function() {
1914
+ // Find dropdowns
1915
+ const providerRadio = document.querySelector('input[name="provider_choice"]');
1916
+ const searchInput = document.getElementById('model_search');
1917
+
1918
+ if (providerRadio && searchInput) {
1919
+ // When provider changes, clear the search
1920
+ providerRadio.addEventListener('change', function() {
1921
+ searchInput.value = '';
1922
+ });
1923
+ }
1924
+ });
1925
+ </script>
1926
+ """)
1927
+
1928
  # Set up submission event
1929
  def submit_message(message, history, provider, openrouter_model, openai_model, hf_model, groq_model, cohere_model, glhf_model,
1930
  temperature, max_tokens, top_p, frequency_penalty, presence_penalty, repetition_penalty,
 
2017
 
2018
  # Launch the app
2019
  if __name__ == "__main__":
2020
+ # Check API keys and print status
2021
+ missing_keys = []
2022
+
2023
  if not OPENROUTER_API_KEY:
2024
  logger.warning("WARNING: OPENROUTER_API_KEY environment variable is not set")
2025
+ missing_keys.append("OpenRouter")
2026
 
2027
+ if not OPENAI_API_KEY:
2028
+ logger.warning("WARNING: OPENAI_API_KEY environment variable is not set")
2029
+ missing_keys.append("OpenAI")
2030
+
2031
+ if not GROQ_API_KEY:
2032
+ logger.warning("WARNING: GROQ_API_KEY environment variable is not set")
2033
+ missing_keys.append("Groq")
2034
+
2035
+ if not COHERE_API_KEY:
2036
+ logger.warning("WARNING: COHERE_API_KEY environment variable is not set")
2037
+ missing_keys.append("Cohere")
2038
+
2039
+ if not GLHF_API_KEY:
2040
+ logger.warning("WARNING: GLHF_API_KEY environment variable is not set")
2041
+ missing_keys.append("GLHF")
2042
+
2043
+ if missing_keys:
2044
+ print("Missing API keys for the following providers:")
2045
+ for key in missing_keys:
2046
+ print(f"- {key}")
2047
+ print("\nYou can still use the application, but some providers will require API keys.")
2048
+ print("You can provide API keys through environment variables or use the API Key Override field.")
2049
+
2050
+ if "OpenRouter" in missing_keys:
2051
+ print("\nNote: OpenRouter offers free tier access to many models!")
2052
+
2053
+ print("\nStarting Multi-Provider CrispChat application...")
2054
  demo = create_app()
2055
  demo.launch(
2056
  server_name="0.0.0.0",