cstr committed on
Commit 547bcde · verified · 1 Parent(s): 4b6c379

Update app.py

Files changed (1)
  1. app.py +1277 -342
app.py CHANGED
@@ -5,6 +5,7 @@ import json
5
  import base64
6
  import logging
7
  import io
 
8
  from typing import List, Dict, Any, Union, Tuple, Optional
9
 
10
  # Configure logging
@@ -14,37 +15,71 @@ logger = logging.getLogger(__name__)
14
  # Gracefully import libraries with fallbacks
15
  try:
16
  from PIL import Image
 
17
  except ImportError:
18
  logger.warning("PIL not installed. Image processing will be limited.")
19
- Image = None
20
 
21
  try:
22
  import PyPDF2
 
23
  except ImportError:
24
  logger.warning("PyPDF2 not installed. PDF processing will be limited.")
25
- PyPDF2 = None
26
 
27
  try:
28
  import markdown
 
29
  except ImportError:
30
  logger.warning("Markdown not installed. Markdown processing will be limited.")
31
- markdown = None
32
 
33
- # API key
34
  OPENROUTER_API_KEY = os.environ.get("OPENROUTER_API_KEY", "")
35
 
36
- # Log API key status (masked for security)
37
- if OPENROUTER_API_KEY:
38
- masked_key = OPENROUTER_API_KEY[:4] + "..." + OPENROUTER_API_KEY[-4:] if len(OPENROUTER_API_KEY) > 8 else "***"
39
- logger.info(f"Using API key: {masked_key}")
40
- else:
41
- logger.warning("No API key provided!")
42
 
43
- # Keep the existing model lists
44
- MODELS = [
 
45
  # 1M+ Context Models
46
  {"category": "1M+ Context", "models": [
47
- ("Google: Gemini Pro 2.0 Experimental", "google/gemini-2.0-pro-exp-02-05:free", 2000000),
48
  ("Google: Gemini 2.0 Flash Thinking Experimental 01-21", "google/gemini-2.0-flash-thinking-exp:free", 1048576),
49
  ("Google: Gemini Flash 2.0 Experimental", "google/gemini-2.0-flash-exp:free", 1048576),
50
  ("Google: Gemini Pro 2.5 Experimental", "google/gemini-2.5-pro-exp-03-25:free", 1000000),
@@ -125,7 +160,7 @@ MODELS = [
125
 
126
  # Vision-capable Models
127
  {"category": "Vision Models", "models": [
128
- ("Google: Gemini Pro 2.0 Experimental", "google/gemini-2.0-pro-exp-02-05:free", 2000000),
129
  ("Google: Gemini 2.0 Flash Thinking Experimental 01-21", "google/gemini-2.0-flash-thinking-exp:free", 1048576),
130
  ("Google: Gemini Flash 2.0 Experimental", "google/gemini-2.0-flash-exp:free", 1048576),
131
  ("Google: Gemini Pro 2.5 Experimental", "google/gemini-2.5-pro-exp-03-25:free", 1000000),
@@ -147,89 +182,159 @@ MODELS = [
147
  ]},
148
  ]
149
 
150
- # Flatten model list for easy searching
151
- ALL_MODELS = []
152
- for category in MODELS:
153
  for model in category["models"]:
154
- if model not in ALL_MODELS: # Avoid duplicates
155
- ALL_MODELS.append(model)
156
 
157
- # Helper functions moved to the top to avoid undefined references
158
- def filter_models(search_term):
159
- """Filter models based on search term"""
160
- if not search_term:
161
- return [model[0] for model in ALL_MODELS], ALL_MODELS[0][0]
162
-
163
- filtered_models = [model[0] for model in ALL_MODELS if search_term.lower() in model[0].lower()]
164
-
165
- if filtered_models:
166
- return filtered_models, filtered_models[0]
167
- else:
168
- return [model[0] for model in ALL_MODELS], ALL_MODELS[0][0]
169
 
170
- def update_context_display(model_name):
171
- """Update context size display for the selected model"""
172
- for model in ALL_MODELS:
173
- if model[0] == model_name:
174
- _, _, context_size = model
175
- context_formatted = f"{context_size:,}"
176
- return f"{context_formatted} tokens"
177
- return "Unknown"
178
 
179
- def update_model_info(model_name):
180
- """Generate HTML info display for the selected model"""
181
- for model in ALL_MODELS:
182
- if model[0] == model_name:
183
- name, model_id, context_size = model
184
-
185
- # Check if this is a vision model
186
- is_vision_model = False
187
- for cat in MODELS:
188
- if cat["category"] == "Vision Models":
189
- if any(m[0] == model_name for m in cat["models"]):
190
- is_vision_model = True
191
- break
192
-
193
- vision_badge = '<span style="background-color: #4CAF50; color: white; padding: 3px 6px; border-radius: 3px; font-size: 0.8em; margin-left: 5px;">Vision</span>' if is_vision_model else ''
194
-
195
- return f"""
196
- <div class="model-info">
197
- <h3>{name} {vision_badge}</h3>
198
- <p><strong>Model ID:</strong> {model_id}</p>
199
- <p><strong>Context Size:</strong> {context_size:,} tokens</p>
200
- <p><strong>Provider:</strong> {model_id.split('/')[0]}</p>
201
- {f'<p><strong>Features:</strong> Supports image understanding</p>' if is_vision_model else ''}
202
- </div>
203
- """
204
- return "<p>Model information not available</p>"
205
 
206
- def update_category_dropdown(category):
207
- """Update the category model dropdown when a category is selected"""
208
- models = get_models_for_category(category)
209
- if not models:
210
- return [], None
211
- return models, models[0]
212
 
213
- def update_category_models_ui(category):
214
- """Completely regenerate the models dropdown based on selected category"""
215
- for cat in MODELS:
216
- if cat["category"] == category:
217
- model_names = [model[0] for model in cat["models"]]
218
- if model_names:
219
- # Return a completely new dropdown component
220
- return gr.Dropdown(
221
- choices=model_names,
222
- value=model_names[0],
223
- label="Models in Category",
224
- allow_custom_value=True
225
- )
226
- # Return empty dropdown if no models found
227
- return gr.Dropdown(
228
- choices=[],
229
- value=None,
230
- label="Models in Category",
231
- allow_custom_value=True
232
- )
233
 
234
  def encode_image_to_base64(image_path):
235
  """Encode an image file to base64 string"""
@@ -271,7 +376,7 @@ def extract_text_from_file(file_path):
271
  file_extension = file_path.split('.')[-1].lower()
272
 
273
  if file_extension == 'pdf':
274
- if PyPDF2 is not None:
275
  text = ""
276
  with open(file_path, 'rb') as file:
277
  pdf_reader = PyPDF2.PdfReader(file)
@@ -379,74 +484,419 @@ def process_uploaded_images(files):
379
  file_paths.append(file.name)
380
  return file_paths
381
 
382
- def get_model_info(model_choice):
383
- """Get model ID and context size from model name"""
384
- for name, model_id_value, ctx_size in ALL_MODELS:
385
- if name == model_choice:
386
- return model_id_value, ctx_size
387
  return None, 0
388
 
389
- def get_models_for_category(category):
390
- """Get model list for a specific category"""
391
- for cat in MODELS:
392
- if cat["category"] == category:
393
- return [model[0] for model in cat["models"]]
394
- return []
395
 
396
- def call_openrouter_api(payload):
397
  """Make a call to OpenRouter API with error handling"""
398
  try:
399
  response = requests.post(
400
  "https://openrouter.ai/api/v1/chat/completions",
401
  headers={
402
  "Content-Type": "application/json",
403
- "Authorization": f"Bearer {OPENROUTER_API_KEY}",
404
- "HTTP-Referer": "https://huggingface.co/spaces/cstr/CrispChat"
405
  },
406
  json=payload,
407
  timeout=180 # Longer timeout for document processing
408
  )
409
  return response
410
  except requests.RequestException as e:
411
- logger.error(f"API request error: {str(e)}")
412
  raise e
413
 
414
- def extract_ai_response(result):
415
- """Extract AI response from OpenRouter API result"""
416
  try:
417
- if "choices" in result and len(result["choices"]) > 0:
418
- if "message" in result["choices"][0]:
419
- # Handle reasoning field if available
420
- message = result["choices"][0]["message"]
421
- if message.get("reasoning") and not message.get("content"):
422
- # Extract response from reasoning if there's no content
423
- reasoning = message.get("reasoning")
424
- # If reasoning contains the actual response, find it
425
- lines = reasoning.strip().split('\n')
426
- for line in lines:
427
- if line and not line.startswith('I should') and not line.startswith('Let me'):
428
- return line.strip()
429
- # If no clear response found, return the first non-empty line
430
- for line in lines:
431
- if line.strip():
432
- return line.strip()
433
- return message.get("content", "")
434
- elif "delta" in result["choices"][0]:
435
- return result["choices"][0]["delta"].get("content", "")
436
 
437
- logger.error(f"Unexpected response structure: {result}")
438
- return "Error: Could not extract response from API result"
439
  except Exception as e:
440
  logger.error(f"Error extracting AI response: {str(e)}")
441
  return f"Error: {str(e)}"
442
 
443
- # streaming code:
444
- def streaming_handler(response, chatbot, message_idx, message):
445
  try:
446
  # First add the user message if needed
447
  if len(chatbot) == message_idx:
448
- chatbot.append({"role": "user", "content": message})
449
- chatbot.append({"role": "assistant", "content": ""})
450
 
451
  for line in response.iter_lines():
452
  if not line:
@@ -465,8 +915,8 @@ def streaming_handler(response, chatbot, message_idx, message):
465
  if "choices" in chunk and len(chunk["choices"]) > 0:
466
  delta = chunk["choices"][0].get("delta", {})
467
  if "content" in delta and delta["content"]:
468
- # Update the last message content
469
- chatbot[-1]["content"] += delta["content"]
470
  yield chatbot
471
  except json.JSONDecodeError:
472
  logger.error(f"Failed to parse JSON from chunk: {data}")
@@ -474,25 +924,82 @@ def streaming_handler(response, chatbot, message_idx, message):
474
  logger.error(f"Error in streaming handler: {str(e)}")
475
  # Add error message to the current response
476
  if len(chatbot) > message_idx:
477
- chatbot[-1]["content"] += f"\n\nError during streaming: {str(e)}"
478
  yield chatbot
479
 
480
- def ask_ai(message, history, model_choice, temperature, max_tokens, top_p,
481
- frequency_penalty, presence_penalty, repetition_penalty, top_k,
482
- min_p, seed, top_a, stream_output, response_format,
483
- images, documents, reasoning_effort, system_message, transforms):
484
- """Redesigned AI query function with proper error handling for Gradio 4.44.1"""
485
  # Validate input
486
  if not message.strip() and not images and not documents:
487
  return history
488
 
489
- # Get model information
490
- model_id, context_size = get_model_info(model_choice)
491
- if not model_id:
492
- logger.error(f"Model not found: {model_choice}")
493
- history.append((message, f"Error: Model '{model_choice}' not found"))
494
- return history
495
-
496
  # Copy history to new list to avoid modifying the original
497
  chat_history = list(history)
498
 
@@ -512,10 +1019,8 @@ def ask_ai(message, history, model_choice, temperature, max_tokens, top_p,
512
  # Add current message
513
  messages.append({"role": "user", "content": content})
514
 
515
- # Build the payload with all parameters
516
- payload = {
517
- "model": model_id,
518
- "messages": messages,
519
  "temperature": temperature,
520
  "max_tokens": max_tokens,
521
  "top_p": top_p,
@@ -524,84 +1029,302 @@ def ask_ai(message, history, model_choice, temperature, max_tokens, top_p,
524
  "stream": stream_output
525
  }
526
 
527
- # Add optional parameters if set
528
- if repetition_penalty != 1.0:
529
- payload["repetition_penalty"] = repetition_penalty
530
-
531
- if top_k > 0:
532
- payload["top_k"] = top_k
533
-
534
- if min_p > 0:
535
- payload["min_p"] = min_p
536
-
537
- if seed > 0:
538
- payload["seed"] = seed
539
-
540
- if top_a > 0:
541
- payload["top_a"] = top_a
542
-
543
- # Add response format if JSON is requested
544
- if response_format == "json_object":
545
- payload["response_format"] = {"type": "json_object"}
546
-
547
- # Add reasoning if selected
548
- if reasoning_effort != "none":
549
- payload["reasoning"] = {
550
- "effort": reasoning_effort
551
- }
552
-
553
- # Add transforms if selected
554
- if transforms:
555
- payload["transforms"] = transforms
556
-
557
- # Log the request
558
- logger.info(f"Sending request to model: {model_id}")
559
- logger.info(f"Request payload: {json.dumps(payload, default=str)}")
560
-
561
  try:
562
- # Call OpenRouter API
563
- response = call_openrouter_api(payload)
564
- logger.info(f"Response status: {response.status_code}")
565
-
566
- # Handle streaming response
567
- if stream_output and response.status_code == 200:
568
- # Add empty response slot to history
569
- chat_history.append([message, ""])
570
 
571
- # Set up generator for streaming updates
572
- def streaming_generator():
573
- for updated_history in streaming_handler(response, chat_history, len(chat_history) - 1, message):
574
- yield updated_history
575
 
576
- return streaming_generator()
577
-
578
- # Handle normal response
579
- elif response.status_code == 200:
580
- result = response.json()
581
- logger.info(f"Response content: {result}")
582
 
583
- # Extract AI response
584
- ai_response = extract_ai_response(result)
585
 
586
- # Log token usage if available
587
- if "usage" in result:
588
- logger.info(f"Token usage: {result['usage']}")
589
 
590
- # Add response to history
591
- chat_history.append({"role": "user", "content": message})
592
- chat_history.append({"role": "assistant", "content": ai_response})
593
- return chat_history
594
-
595
- # Handle error response
596
- else:
597
- error_message = f"Error: Status code {response.status_code}"
598
  try:
599
- response_data = response.json()
600
- error_message += f"\n\nDetails: {json.dumps(response_data, indent=2)}"
601
- except:
602
- error_message += f"\n\nResponse: {response.text}"
603
 
604
- logger.error(error_message)
605
  chat_history.append([message, error_message])
606
  return chat_history
607
 
@@ -615,10 +1338,14 @@ def clear_chat():
615
  """Reset all inputs"""
616
  return [], "", [], [], 0.7, 1000, 0.8, 0.0, 0.0, 1.0, 40, 0.1, 0, 0.0, False, "default", "none", "", []
617
 
618
  def create_app():
619
- """Create the Gradio application with improved UI and response handling"""
620
  with gr.Blocks(
621
- title="CrispChat - AI Assistant",
622
  css="""
623
  .context-size {
624
  font-size: 0.9em;
@@ -643,30 +1370,30 @@ def create_app():
643
  font-size: 0.8em;
644
  margin-left: 5px;
645
  }
646
  """
647
  ) as demo:
648
  gr.Markdown("""
649
- # CrispChat AI Assistant
650
 
651
- Chat with various AI models from OpenRouter with support for images and documents.
652
  """)
653
 
654
  with gr.Row():
655
  with gr.Column(scale=2):
656
- # Chatbot interface - properly configured for Gradio 4.44.1
657
  chatbot = gr.Chatbot(
658
  height=500,
659
  show_copy_button=True,
660
  show_label=False,
661
  avatar_images=(None, "https://upload.wikimedia.org/wikipedia/commons/0/04/ChatGPT_logo.svg"),
662
- type="messages", # Explicitly set the type to messages
663
- elem_id="chat-window" # Add elem_id for debugging
664
- )
665
-
666
- # Debug output for development
667
- debug_output = gr.JSON(
668
- label="Debug Output (Hidden in Production)",
669
- visible=False
670
  )
671
 
672
  with gr.Row():
@@ -674,7 +1401,7 @@ def create_app():
674
  placeholder="Type your message here...",
675
  label="Message",
676
  lines=2,
677
- elem_id="message-input", # Add elem_id for debugging
678
  scale=4
679
  )
680
 
@@ -709,6 +1436,23 @@ def create_app():
709
  )
710
 
711
  with gr.Column(scale=1):
712
  with gr.Group():
713
  gr.Markdown("### Model Selection")
714
 
@@ -719,39 +1463,61 @@ def create_app():
719
  show_label=False
720
  )
721
 
722
- with gr.Row(elem_classes="model-selection-row"):
723
-
724
- # Main model dropdown
725
- model_choice = gr.Dropdown(
726
- [model[0] for model in ALL_MODELS],
727
- value=ALL_MODELS[0][0],
728
- label="Model",
729
- elem_id="model-choice",
730
- allow_custom_value=True
731
- )
732
-
733
- context_display = gr.Textbox(
734
- value=update_context_display(ALL_MODELS[0][0]),
735
- label="Context",
736
- interactive=False,
737
- elem_classes="context-size"
738
- )
739
 
740
- # Model category selection
741
- with gr.Accordion("Browse by Category", open=False):
742
- model_categories = gr.Dropdown(
743
- [model["category"] for model in MODELS],
744
- label="Categories",
745
- value=MODELS[0]["category"]
746
- )
747
-
748
- # Models in category dropdown
749
- category_models = gr.Dropdown(
750
- get_models_for_category(MODELS[0]["category"]),
751
- label="Models in Category",
752
- value=get_models_for_category(MODELS[0]["category"])[0] if get_models_for_category(MODELS[0]["category"]) else None,
753
- allow_custom_value=True
754
- )
755
 
756
  with gr.Accordion("Generation Parameters", open=False):
757
  with gr.Group(elem_classes="parameter-grid"):
@@ -798,7 +1564,7 @@ def create_app():
798
  reasoning_effort = gr.Radio(
799
  ["none", "low", "medium", "high"],
800
  value="none",
801
- label="Reasoning Effort"
802
  )
803
 
804
  with gr.Accordion("Advanced Options", open=False):
@@ -857,7 +1623,7 @@ def create_app():
857
 
858
  gr.Markdown("""
859
  * **json_object**: Forces the model to respond with valid JSON only.
860
- * Only available on certain models - check model support on OpenRouter.
861
  """)
862
 
863
  # Custom instructing options
@@ -882,7 +1648,7 @@ def create_app():
882
  # Add a model information section
883
  with gr.Accordion("About Selected Model", open=False):
884
  model_info_display = gr.HTML(
885
- value=update_model_info(ALL_MODELS[0][0])
886
  )
887
 
888
  # Add usage instructions
@@ -890,88 +1656,270 @@ def create_app():
890
  gr.Markdown("""
891
  ## Basic Usage
892
  1. Type your message in the input box
893
- 2. Select a model from the dropdown
894
  3. Click "Send" or press Enter
895
 
896
  ## Working with Files
897
  - **Images**: Upload images to use with vision-capable models
898
  - **Documents**: Upload PDF, Markdown, or text files to analyze their content
899
 
900
  ## Advanced Parameters
901
  - **Temperature**: Controls randomness (higher = more creative, lower = more deterministic)
902
  - **Max Tokens**: Maximum length of the response
903
  - **Top P**: Nucleus sampling threshold (higher = consider more tokens)
904
- - **Reasoning Effort**: Some models can show their reasoning process
905
-
906
- ## Tips
907
- - For code generation, use models like Qwen Coder
908
- - For visual tasks, choose vision-capable models
909
- - For long context, check the context window size next to the model name
910
  """)
911
 
912
  # Add a footer with version info
913
  footer_md = gr.Markdown("""
914
  ---
915
- ### CrispChat v1.1
916
- Built with ❤️ using Gradio 4.44.1 and OpenRouter API | Context sizes shown next to model names
917
  """)
918
 
919
- # Define a test function for debugging
920
- def test_chatbot(test_message):
921
- """Simple test function to verify chatbot updates work"""
922
- logger.info(f"Test function called with: {test_message}")
923
- return [[test_message, "This is a test response to verify the chatbot is working"]]
924
 
925
- # Connect model search to dropdown filter
926
  model_search.change(
927
- fn=filter_models,
928
- inputs=model_search,
929
- outputs=[model_choice, model_choice]
930
  )
931
 
932
- # Update context display when model changes
933
- model_choice.change(
934
- fn=update_context_display,
935
- inputs=model_choice,
936
  outputs=context_display
937
  )
938
 
939
- # Update model info when model changes
940
- model_choice.change(
941
- fn=update_model_info,
942
- inputs=model_choice,
943
  outputs=model_info_display
944
  )
945
 
946
- # Update model list when category changes
947
- model_categories.change(
948
- fn=update_category_models_ui,
949
- inputs=model_categories,
950
- outputs=category_models
951
- )
952
 
953
- # Update main model choice when category model is selected
954
- category_models.change(
955
- fn=lambda x: x,
956
- inputs=category_models,
957
- outputs=model_choice
958
  )
959
-
960
- # Process uploaded images
961
- image_upload_btn.upload(
962
- fn=lambda files: files,
963
- inputs=image_upload_btn,
964
- outputs=images
965
  )
966
 
967
- # Set up events for the submit button
968
  submit_btn.click(
969
- fn=ask_ai,
970
  inputs=[
971
- message, chatbot, model_choice, temperature, max_tokens,
972
- top_p, frequency_penalty, presence_penalty, repetition_penalty,
973
  top_k, min_p, seed, top_a, stream_output, response_format,
974
- images, documents, reasoning_effort, system_message, transforms
975
  ],
976
  outputs=chatbot,
977
  show_progress="minimal",
@@ -981,14 +1929,15 @@ def create_app():
981
  outputs=message
982
  )
983
 
984
- # Set up events for message submission (pressing Enter)
985
  message.submit(
986
- fn=ask_ai,
987
  inputs=[
988
- message, chatbot, model_choice, temperature, max_tokens,
989
- top_p, frequency_penalty, presence_penalty, repetition_penalty,
990
  top_k, min_p, seed, top_a, stream_output, response_format,
991
- images, documents, reasoning_effort, system_message, transforms
992
  ],
993
  outputs=chatbot,
994
  show_progress="minimal",
@@ -998,7 +1947,7 @@ def create_app():
998
  outputs=message
999
  )
1000
 
1001
- # Set up events for the clear button
1002
  clear_btn.click(
1003
  fn=clear_chat,
1004
  inputs=[],
@@ -1010,28 +1959,14 @@ def create_app():
1010
  ]
1011
  )
1012
 
1013
- # Debug button (hidden in production)
1014
- debug_btn = gr.Button("Debug Chatbot", visible=False)
1015
- debug_btn.click(
1016
- fn=test_chatbot,
1017
- inputs=[message],
1018
- outputs=[chatbot]
1019
- )
1020
-
1021
- # Enable debugging for key components
1022
- # gr.debug(chatbot)
1023
-
1024
  return demo
1025
 
1026
-
1027
-
1028
-
1029
  # Launch the app
1030
  if __name__ == "__main__":
1031
- # Check API key before starting
1032
  if not OPENROUTER_API_KEY:
1033
  logger.warning("WARNING: OPENROUTER_API_KEY environment variable is not set")
1034
- print("WARNING: OpenRouter API key not found. Set OPENROUTER_API_KEY environment variable.")
1035
 
1036
  demo = create_app()
1037
  demo.launch(
 
5
  import base64
6
  import logging
7
  import io
8
+ import time
9
  from typing import List, Dict, Any, Union, Tuple, Optional
10
 
11
  # Configure logging
 
15
  # Gracefully import libraries with fallbacks
16
  try:
17
  from PIL import Image
18
+ HAS_PIL = True
19
  except ImportError:
20
  logger.warning("PIL not installed. Image processing will be limited.")
21
+ HAS_PIL = False
22
 
23
  try:
24
  import PyPDF2
25
+ HAS_PYPDF2 = True
26
  except ImportError:
27
  logger.warning("PyPDF2 not installed. PDF processing will be limited.")
28
+ HAS_PYPDF2 = False
29
 
30
  try:
31
  import markdown
32
+ HAS_MARKDOWN = True
33
  except ImportError:
34
  logger.warning("Markdown not installed. Markdown processing will be limited.")
35
+ HAS_MARKDOWN = False
36
 
37
+ try:
38
+ import openai
39
+ HAS_OPENAI = True
40
+ except ImportError:
41
+ logger.warning("OpenAI package not installed. OpenAI models will be unavailable.")
42
+ HAS_OPENAI = False
43
+
44
+ try:
45
+ from groq import Groq
46
+ HAS_GROQ = True
47
+ except ImportError:
48
+ logger.warning("Groq client not installed. Groq API will be unavailable.")
49
+ HAS_GROQ = False
50
+
51
+ try:
52
+ import cohere
53
+ HAS_COHERE = True
54
+ except ImportError:
55
+ logger.warning("Cohere package not installed. Cohere models will be unavailable.")
56
+ HAS_COHERE = False
57
+
58
+ try:
59
+ from huggingface_hub import InferenceClient
60
+ HAS_HF = True
61
+ except ImportError:
62
+ logger.warning("HuggingFace hub not installed. HuggingFace models will be limited.")
63
+ HAS_HF = False
64
+
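# A minimal sketch of how the HAS_* flags above are meant to guard optional
# features at call sites (illustrative only, not part of the commit):
#
#     if HAS_PIL:
#         img = Image.open(image_path)
#     else:
#         logger.warning("Skipping image: PIL is not available")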
65
+ # API keys from environment
66
  OPENROUTER_API_KEY = os.environ.get("OPENROUTER_API_KEY", "")
67
+ OPENAI_API_KEY = os.environ.get("OPENAI_API_KEY", "")
68
+ GROQ_API_KEY = os.environ.get("GROQ_API_KEY", "")
69
+ COHERE_API_KEY = os.environ.get("COHERE_API_KEY", "")
70
+ GLHF_API_KEY = os.environ.get("GLHF_API_KEY", "")
71
+ HF_API_KEY = os.environ.get("HF_API_KEY", "")
72
 
73
+ # ==========================================================
74
+ # MODEL DEFINITIONS
75
+ # ==========================================================
76
 
77
+ # OPENROUTER MODELS
78
+ # These are the original models from the provided code
79
+ OPENROUTER_MODELS = [
80
  # 1M+ Context Models
81
  {"category": "1M+ Context", "models": [
82
+ #("Google: Gemini Pro 2.0 Experimental", "google/gemini-2.0-pro-exp-02-05:free", 2000000),
83
  ("Google: Gemini 2.0 Flash Thinking Experimental 01-21", "google/gemini-2.0-flash-thinking-exp:free", 1048576),
84
  ("Google: Gemini Flash 2.0 Experimental", "google/gemini-2.0-flash-exp:free", 1048576),
85
  ("Google: Gemini Pro 2.5 Experimental", "google/gemini-2.5-pro-exp-03-25:free", 1000000),
 
160
 
161
  # Vision-capable Models
162
  {"category": "Vision Models", "models": [
163
+ #("Google: Gemini Pro 2.0 Experimental", "google/gemini-2.0-pro-exp-02-05:free", 2000000),
164
  ("Google: Gemini 2.0 Flash Thinking Experimental 01-21", "google/gemini-2.0-flash-thinking-exp:free", 1048576),
165
  ("Google: Gemini Flash 2.0 Experimental", "google/gemini-2.0-flash-exp:free", 1048576),
166
  ("Google: Gemini Pro 2.5 Experimental", "google/gemini-2.5-pro-exp-03-25:free", 1000000),
 
182
  ]},
183
  ]
184
 
185
+ # Flatten OpenRouter model list for easier access
186
+ OPENROUTER_ALL_MODELS = []
187
+ for category in OPENROUTER_MODELS:
188
  for model in category["models"]:
189
+ if model not in OPENROUTER_ALL_MODELS: # Avoid duplicates
190
+ OPENROUTER_ALL_MODELS.append(model)
191
 
192
+ # OPENAI MODELS
193
+ OPENAI_MODELS = {
194
+ "gpt-3.5-turbo": 16385,
195
+ "gpt-3.5-turbo-0125": 16385,
196
+ "gpt-3.5-turbo-1106": 16385,
197
+ "gpt-3.5-turbo-instruct": 4096,
198
+ "gpt-4": 8192,
199
+ "gpt-4-0314": 8192,
200
+ "gpt-4-0613": 8192,
201
+ "gpt-4-turbo": 128000,
202
+ "gpt-4-turbo-2024-04-09": 128000,
203
+ "gpt-4-turbo-preview": 128000,
204
+ "gpt-4-0125-preview": 128000,
205
+ "gpt-4-1106-preview": 128000,
206
+ "gpt-4o": 128000,
207
+ "gpt-4o-2024-11-20": 128000,
208
+ "gpt-4o-2024-08-06": 128000,
209
+ "gpt-4o-2024-05-13": 128000,
210
+ "chatgpt-4o-latest": 128000,
211
+ "gpt-4o-mini": 128000,
212
+ "gpt-4o-mini-2024-07-18": 128000,
213
+ "gpt-4o-realtime-preview": 128000,
214
+ "gpt-4o-realtime-preview-2024-10-01": 128000,
215
+ "gpt-4o-audio-preview": 128000,
216
+ "gpt-4o-audio-preview-2024-10-01": 128000,
217
+ "o1-preview": 128000,
218
+ "o1-preview-2024-09-12": 128000,
219
+ "o1-mini": 128000,
220
+ "o1-mini-2024-09-12": 128000,
221
+ }
222
 
223
+ # HUGGINGFACE MODELS
224
+ HUGGINGFACE_MODELS = {
225
+ "microsoft/phi-3-mini-4k-instruct": 4096,
226
+ "microsoft/Phi-3-mini-128k-instruct": 131072,
227
+ "HuggingFaceH4/zephyr-7b-beta": 8192,
228
+ "deepseek-ai/DeepSeek-Coder-V2-Instruct": 8192,
229
+ "mistralai/Mistral-7B-Instruct-v0.3": 32768,
230
+ "NousResearch/Nous-Hermes-2-Mixtral-8x7B-DPO": 32768,
231
+ "microsoft/Phi-3.5-mini-instruct": 4096,
232
+ "HuggingFaceTB/SmolLM2-1.7B-Instruct": 2048,
233
+ "google/gemma-2-2b-it": 2048,
234
+ "openai-community/gpt2": 1024,
235
+ "microsoft/phi-2": 2048,
236
+ "TinyLlama/TinyLlama-1.1B-Chat-v1.0": 2048,
237
+ "VAGOsolutions/Llama-3-SauerkrautLM-8b-Instruct": 2048,
238
+ "VAGOsolutions/Llama-3.1-SauerkrautLM-8b-Instruct": 4096,
239
+ "VAGOsolutions/SauerkrautLM-Nemo-12b-Instruct": 4096,
240
+ "openGPT-X/Teuken-7B-instruct-research-v0.4": 4096,
241
+ "Qwen/Qwen2.5-7B-Instruct": 131072,
242
+ "tiiuae/falcon-7b-instruct": 8192,
243
+ "Qwen/QwQ-32B-preview": 32768,
244
+ }
245
 
246
+ # GROQ MODELS - We'll populate this dynamically
247
+ DEFAULT_GROQ_MODELS = {
248
+ "gemma2-9b-it": 8192,
249
+ "gemma-7b-it": 8192,
250
+ "llama-3.3-70b-versatile": 131072,
251
+ "llama-3.1-70b-versatile": 131072,
252
+ "llama-3.1-8b-instant": 131072,
253
+ "llama-guard-3-8b": 8192,
254
+ "llama3-70b-8192": 8192,
255
+ "llama3-8b-8192": 8192,
256
+ "mixtral-8x7b-32768": 32768,
257
+ "llama3-groq-70b-8192-tool-use-preview": 8192,
258
+ "llama3-groq-8b-8192-tool-use-preview": 8192,
259
+ "llama-3.3-70b-specdec": 131072,
260
+ "llama-3.1-70b-specdec": 131072,
261
+ "llama-3.2-1b-preview": 131072,
262
+ "llama-3.2-3b-preview": 131072,
263
+ }
264
 
265
+ # COHERE MODELS
266
+ COHERE_MODELS = {
267
+ "command-r-plus-08-2024": 131072,
268
+ "command-r-plus-04-2024": 131072,
269
+ "command-r-plus": 131072,
270
+ "command-r-08-2024": 131072,
271
+ "command-r-03-2024": 131072,
272
+ "command-r": 131072,
273
+ "command": 4096,
274
+ "command-nightly": 131072,
275
+ "command-light": 4096,
276
+ "command-light-nightly": 4096,
277
+ "c4ai-aya-expanse-8b": 8192,
278
+ "c4ai-aya-expanse-32b": 131072,
279
+ }
280
 
281
+ # GLHF MODELS
282
+ GLHF_MODELS = {
283
+ "mistralai/Mixtral-8x7B-Instruct-v0.1": 32768,
284
+ "01-ai/Yi-34B-Chat": 32768,
285
+ "mistralai/Mistral-7B-Instruct-v0.3": 32768,
286
+ "microsoft/phi-3-mini-4k-instruct": 4096,
287
+ "microsoft/Phi-3.5-mini-instruct": 4096,
288
+ "microsoft/Phi-3-mini-128k-instruct": 131072,
289
+ "HuggingFaceH4/zephyr-7b-beta": 8192,
290
+ "NousResearch/Nous-Hermes-2-Mixtral-8x7B-DPO": 32768,
291
+ "google/gemma-2-2b-it": 2048,
292
+ "microsoft/phi-2": 2048,
293
+ }
294
+
295
+ # ==========================================================
296
+ # HELPER FUNCTIONS
297
+ # ==========================================================
298
+
299
+ def fetch_groq_models():
300
+ """Fetch available Groq models with proper error handling"""
301
+ try:
302
+ if not HAS_GROQ or not GROQ_API_KEY:
303
+ logger.warning("Groq client not available or no API key. Using default model list.")
304
+ return DEFAULT_GROQ_MODELS
305
+
306
+ client = Groq(api_key=GROQ_API_KEY)
307
+ models = client.models.list()
308
+
309
+ # Create dictionary of model_id -> context size
310
+ model_dict = {}
311
+ for model in models.data:
312
+ model_id = model.id
313
+ # Map known context sizes or use a default
314
+ if "llama-3" in model_id and "70b" in model_id:
315
+ context_size = 131072
316
+ elif "llama-3" in model_id and "8b" in model_id:
317
+ context_size = 131072
318
+ elif "mixtral" in model_id:
319
+ context_size = 32768
320
+ elif "gemma" in model_id:
321
+ context_size = 8192
322
+ else:
323
+ context_size = 8192 # Default assumption
324
+
325
+ model_dict[model_id] = context_size
326
+
327
+ # Ensure we have models by combining with defaults
328
+ if not model_dict:
329
+ return DEFAULT_GROQ_MODELS
330
+ return {**DEFAULT_GROQ_MODELS, **model_dict}
331
+
332
+ except Exception as e:
333
+ logger.error(f"Error fetching Groq models: {e}")
334
+ return DEFAULT_GROQ_MODELS
335
+
336
+ # Initialize Groq models
337
+ GROQ_MODELS = fetch_groq_models()
338
 
339
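# Quick sanity check (sketch, not part of the commit): GROQ_MODELS is always a
# non-empty dict, because fetch_groq_models() falls back to DEFAULT_GROQ_MODELS
# whenever the Groq client or API key is missing, or the request fails:
#
#     assert isinstance(GROQ_MODELS, dict) and GROQ_MODELS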
  def encode_image_to_base64(image_path):
340
  """Encode an image file to base64 string"""
 
376
  file_extension = file_path.split('.')[-1].lower()
377
 
378
  if file_extension == 'pdf':
379
+ if HAS_PYPDF2:
380
  text = ""
381
  with open(file_path, 'rb') as file:
382
  pdf_reader = PyPDF2.PdfReader(file)
 
484
  file_paths.append(file.name)
485
  return file_paths
486
 
487
+ def filter_models(provider, search_term):
488
+ """Filter models based on search term and provider"""
489
+ if provider == "OpenRouter":
490
+ all_models = [model[0] for model in OPENROUTER_ALL_MODELS]
491
+ elif provider == "OpenAI":
492
+ all_models = list(OPENAI_MODELS.keys())
493
+ elif provider == "HuggingFace":
494
+ all_models = list(HUGGINGFACE_MODELS.keys())
495
+ elif provider == "Groq":
496
+ all_models = list(GROQ_MODELS.keys())
497
+ elif provider == "Cohere":
498
+ all_models = list(COHERE_MODELS.keys())
499
+ elif provider == "GLHF":
500
+ all_models = list(GLHF_MODELS.keys())
501
+ else:
502
+ return [], None
503
+
504
+ if not search_term:
505
+ return all_models, all_models[0] if all_models else None
506
+
507
+ filtered_models = [model for model in all_models if search_term.lower() in model.lower()]
508
+
509
+ if filtered_models:
510
+ return filtered_models, filtered_models[0]
511
+ else:
512
+ return all_models, all_models[0] if all_models else None
515
+
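# Illustrative call (sketch; the search term is an assumption, not from the commit):
#
#     models, default = filter_models("Groq", "llama")
#     # -> (all Groq model ids containing "llama", with the first match as the default)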
516
+ def get_model_info(provider, model_choice):
517
+ """Get model ID and context size based on provider and model name"""
518
+ if provider == "OpenRouter":
519
+ for name, model_id, ctx_size in OPENROUTER_ALL_MODELS:
520
+ if name == model_choice:
521
+ return model_id, ctx_size
522
+ elif provider == "OpenAI":
523
+ if model_choice in OPENAI_MODELS:
524
+ return model_choice, OPENAI_MODELS[model_choice]
525
+ elif provider == "HuggingFace":
526
+ if model_choice in HUGGINGFACE_MODELS:
527
+ return model_choice, HUGGINGFACE_MODELS[model_choice]
528
+ elif provider == "Groq":
529
+ if model_choice in GROQ_MODELS:
530
+ return model_choice, GROQ_MODELS[model_choice]
531
+ elif provider == "Cohere":
532
+ if model_choice in COHERE_MODELS:
533
+ return model_choice, COHERE_MODELS[model_choice]
534
+ elif provider == "GLHF":
535
+ if model_choice in GLHF_MODELS:
536
+ return model_choice, GLHF_MODELS[model_choice]
537
+
538
  return None, 0
539
 
540
+ def update_context_display(provider, model_name):
541
+ """Update context size display for the selected model"""
542
+ _, ctx_size = get_model_info(provider, model_name)
543
+ return f"{ctx_size:,}" if ctx_size else "Unknown"
544
 
545
+ def update_model_info(provider, model_name):
546
+ """Generate HTML info display for the selected model"""
547
+ model_id, ctx_size = get_model_info(provider, model_name)
548
+ if not model_id:
549
+ return "<p>Model information not available</p>"
550
+
551
+ # Check if this is a vision model
552
+ is_vision_model = False
553
+
554
+ # For OpenRouter, check the vision models category
555
+ if provider == "OpenRouter":
556
+ for cat in OPENROUTER_MODELS:
557
+ if cat["category"] == "Vision Models":
558
+ if any(m[0] == model_name for m in cat["models"]):
559
+ is_vision_model = True
560
+ break
561
+ # For other providers, use heuristics
562
+ elif provider == "OpenAI" and any(x in model_name.lower() for x in ["gpt-4", "gpt-4o"]):
563
+ is_vision_model = True
564
+ elif provider == "HuggingFace" and any(x in model_name.lower() for x in ["vl", "vision"]):
565
+ is_vision_model = True
566
+
567
+ vision_badge = '<span style="background-color: #4CAF50; color: white; padding: 3px 6px; border-radius: 3px; font-size: 0.8em; margin-left: 5px;">Vision</span>' if is_vision_model else ''
568
+
569
+ # For OpenRouter, show the model ID; for other providers the ID equals the name, so it is omitted
570
+ model_id_html = f"<p><strong>Model ID:</strong> {model_id}</p>" if provider == "OpenRouter" else ""
575
+
576
+ return f"""
577
+ <div class="model-info">
578
+ <h3>{model_name} {vision_badge}</h3>
579
+ {model_id_html}
580
+ <p><strong>Context Size:</strong> {ctx_size:,} tokens</p>
581
+ <p><strong>Provider:</strong> {provider}</p>
582
+ {f'<p><strong>Features:</strong> Supports image understanding</p>' if is_vision_model else ''}
583
+ </div>
584
+ """
585
+
586
+ # ==========================================================
587
+ # API HANDLERS
588
+ # ==========================================================
589
+
590
+ def call_openrouter_api(payload, api_key_override=None):
591
  """Make a call to OpenRouter API with error handling"""
592
  try:
593
+ api_key = api_key_override if api_key_override else OPENROUTER_API_KEY
594
+ if not api_key:
595
+ raise ValueError("OpenRouter API key is required")
596
+
597
  response = requests.post(
598
  "https://openrouter.ai/api/v1/chat/completions",
599
  headers={
600
  "Content-Type": "application/json",
601
+ "Authorization": f"Bearer {api_key}",
602
+ "HTTP-Referer": "https://huggingface.co/spaces/user/MultiProviderCrispChat"
603
  },
604
  json=payload,
605
  timeout=180 # Longer timeout for document processing
606
  )
607
  return response
608
  except requests.RequestException as e:
609
+ logger.error(f"OpenRouter API request error: {str(e)}")
610
  raise e
611
 
612
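# Minimal usage sketch (not part of the commit; the model id is one entry from
# OPENROUTER_ALL_MODELS and the prompt is illustrative):
#
#     payload = {
#         "model": "google/gemini-2.0-flash-exp:free",
#         "messages": [{"role": "user", "content": "Hello"}],
#         "stream": False,
#     }
#     response = call_openrouter_api(payload)
#     if response.status_code == 200:
#         print(response.json()["choices"][0]["message"]["content"])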
+ def call_openai_api(payload, api_key_override=None):
613
+ """Make a call to OpenAI API with error handling"""
614
  try:
615
+ if not HAS_OPENAI:
616
+ raise ImportError("OpenAI package not installed")
617
+
618
+ api_key = api_key_override if api_key_override else OPENAI_API_KEY
619
+ if not api_key:
620
+ raise ValueError("OpenAI API key is required")
621
+
622
+ client = openai.OpenAI(api_key=api_key)
623
 
624
+ # Extract parameters from payload
625
+ model = payload.get("model", "gpt-3.5-turbo")
626
+ messages = payload.get("messages", [])
627
+ temperature = payload.get("temperature", 0.7)
628
+ max_tokens = payload.get("max_tokens", 1000)
629
+ stream = payload.get("stream", False)
630
+ top_p = payload.get("top_p", 0.9)
631
+ presence_penalty = payload.get("presence_penalty", 0)
632
+ frequency_penalty = payload.get("frequency_penalty", 0)
633
+
634
+ # Handle response format if specified
635
+ response_format = None
636
+ if payload.get("response_format") == "json_object":
637
+ response_format = {"type": "json_object"}
638
+
639
+ # Create completion
640
+ response = client.chat.completions.create(
641
+ model=model,
642
+ messages=messages,
643
+ temperature=temperature,
644
+ max_tokens=max_tokens,
645
+ stream=stream,
646
+ top_p=top_p,
647
+ presence_penalty=presence_penalty,
648
+ frequency_penalty=frequency_penalty,
649
+ response_format=response_format
650
+ )
651
+
652
+ return response
653
+ except Exception as e:
654
+ logger.error(f"OpenAI API error: {str(e)}")
655
+ raise e
656
+
657
+ def call_huggingface_api(payload, api_key_override=None):
658
+ """Make a call to HuggingFace API with error handling"""
659
+ try:
660
+ if not HAS_HF:
661
+ raise ImportError("HuggingFace hub not installed")
662
+
663
+ api_key = api_key_override if api_key_override else HF_API_KEY
664
+
665
+ # Extract parameters from payload
666
+ model_id = payload.get("model", "mistralai/Mistral-7B-Instruct-v0.3")
667
+ messages = payload.get("messages", [])
668
+ temperature = payload.get("temperature", 0.7)
669
+ max_tokens = payload.get("max_tokens", 500)
670
+
671
+ # Create a prompt from messages
672
+ prompt = ""
673
+ for msg in messages:
674
+ role = msg["role"].upper()
675
+ content = msg["content"]
676
+
677
+ # Handle multimodal content
678
+ if isinstance(content, list):
679
+ text_parts = []
680
+ for item in content:
681
+ if item["type"] == "text":
682
+ text_parts.append(item["text"])
683
+ content = "\n".join(text_parts)
684
+
685
+ prompt += f"{role}: {content}\n"
686
+
687
+ prompt += "ASSISTANT: "
688
+
689
+ # Create client with or without API key
690
+ client = InferenceClient(token=api_key) if api_key else InferenceClient()
691
+
692
+ # Generate response
693
+ response = client.text_generation(
694
+ prompt,
695
+ model=model_id,
696
+ max_new_tokens=max_tokens,
697
+ temperature=temperature,
698
+ repetition_penalty=1.1
699
+ )
700
+
701
+ return {"generated_text": str(response)}
702
+ except Exception as e:
703
+ logger.error(f"HuggingFace API error: {str(e)}")
704
+ raise e
705
+
706
+ def call_groq_api(payload, api_key_override=None):
707
+ """Make a call to Groq API with error handling"""
708
+ try:
709
+ if not HAS_GROQ:
710
+ raise ImportError("Groq client not installed")
711
+
712
+ api_key = api_key_override if api_key_override else GROQ_API_KEY
713
+ if not api_key:
714
+ raise ValueError("Groq API key is required")
715
+
716
+ client = Groq(api_key=api_key)
717
+
718
+ # Extract parameters from payload
719
+ model = payload.get("model", "llama-3.1-8b-instant")
720
+ messages = payload.get("messages", [])
721
+ temperature = payload.get("temperature", 0.7)
722
+ max_tokens = payload.get("max_tokens", 1000)
723
+ stream = payload.get("stream", False)
724
+ top_p = payload.get("top_p", 0.9)
725
+
726
+ # Create completion
727
+ response = client.chat.completions.create(
728
+ model=model,
729
+ messages=messages,
730
+ temperature=temperature,
731
+ max_tokens=max_tokens,
732
+ stream=stream,
733
+ top_p=top_p
734
+ )
735
+
736
+ return response
737
+ except Exception as e:
738
+ logger.error(f"Groq API error: {str(e)}")
739
+ raise e
740
+
741
+ def call_cohere_api(payload, api_key_override=None):
742
+ """Make a call to Cohere API with error handling"""
743
+ try:
744
+ if not HAS_COHERE:
745
+ raise ImportError("Cohere package not installed")
746
+
747
+ api_key = api_key_override if api_key_override else COHERE_API_KEY
748
+ if not api_key:
749
+ raise ValueError("Cohere API key is required")
750
+
751
+ client = cohere.Client(api_key=api_key)
752
+
753
+ # Extract parameters from payload
754
+ model = payload.get("model", "command-r-plus")
755
+ messages = payload.get("messages", [])
756
+ temperature = payload.get("temperature", 0.7)
757
+ max_tokens = payload.get("max_tokens", 1000)
758
+
759
+ # Format messages for Cohere
760
+ chat_history = []
761
+ user_message = ""
762
+
763
+ for msg in messages:
764
+ if msg["role"] == "system":
765
+ # For system message, we'll prepend to the user's first message
766
+ system_content = msg["content"]
767
+ if isinstance(system_content, list): # Handle multimodal content
768
+ system_parts = []
769
+ for item in system_content:
770
+ if item["type"] == "text":
771
+ system_parts.append(item["text"])
772
+ system_content = "\n".join(system_parts)
773
+ user_message = f"System: {system_content}\n\n" + user_message
774
+ elif msg["role"] == "user":
775
+ content = msg["content"]
776
+ # Handle multimodal content
777
+ if isinstance(content, list):
778
+ text_parts = []
779
+ for item in content:
780
+ if item["type"] == "text":
781
+ text_parts.append(item["text"])
782
+ content = "\n".join(text_parts)
783
+ user_message = content
784
+ elif msg["role"] == "assistant":
785
+ content = msg["content"]
786
+ if content:
787
+ chat_history.append({"role": "ASSISTANT", "message": content})
788
+
789
+ # Create chat completion
790
+ response = client.chat(
791
+ message=user_message,
792
+ chat_history=chat_history,
793
+ model=model,
794
+ temperature=temperature,
795
+ max_tokens=max_tokens
796
+ )
797
+
798
+ return response
799
+ except Exception as e:
800
+ logger.error(f"Cohere API error: {str(e)}")
801
+ raise e
802
+
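# Note: the adapter above flattens OpenAI-style messages into Cohere's
# (message, chat_history) form: multimodal parts are reduced to their text items,
# and any system message is prepended to the user turn as "System: ...".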
803
+ def call_glhf_api(payload, api_key_override=None):
804
+ """Make a call to GLHF API with error handling"""
805
+ try:
806
+ if not HAS_OPENAI:
807
+ raise ImportError("OpenAI package not installed (required for GLHF API)")
808
+
809
+ api_key = api_key_override if api_key_override else GLHF_API_KEY
810
+ if not api_key:
811
+ raise ValueError("GLHF API key is required")
812
+
813
+ client = openai.OpenAI(
814
+ api_key=api_key,
815
+ base_url="https://glhf.chat/api/openai/v1"
816
+ )
817
+
818
+ # Extract parameters from payload
819
+ model_name = payload.get("model", "mistralai/Mistral-7B-Instruct-v0.3")
820
+ # Add "hf:" prefix if not already there
821
+ if not model_name.startswith("hf:"):
822
+ model = f"hf:{model_name}"
823
+ else:
824
+ model = model_name
825
+
826
+ messages = payload.get("messages", [])
827
+ temperature = payload.get("temperature", 0.7)
828
+ max_tokens = payload.get("max_tokens", 1000)
829
+ stream = payload.get("stream", False)
830
+
831
+ # Create completion
832
+ response = client.chat.completions.create(
833
+ model=model,
834
+ messages=messages,
835
+ temperature=temperature,
836
+ max_tokens=max_tokens,
837
+ stream=stream
838
+ )
839
+
840
+ return response
841
+ except Exception as e:
842
+ logger.error(f"GLHF API error: {str(e)}")
843
+ raise e
844
+
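# Note: GLHF exposes HuggingFace-hosted models through an OpenAI-compatible
# endpoint, so model ids are prefixed with "hf:", e.g.
#     "mistralai/Mistral-7B-Instruct-v0.3" -> "hf:mistralai/Mistral-7B-Instruct-v0.3"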
845
+ def extract_ai_response(result, provider):
846
+ """Extract AI response based on provider format"""
847
+ try:
848
+ if provider == "OpenRouter":
849
+ if isinstance(result, dict):
850
+ if "choices" in result and len(result["choices"]) > 0:
851
+ if "message" in result["choices"][0]:
852
+ message = result["choices"][0]["message"]
853
+ if message.get("reasoning") and not message.get("content"):
854
+ reasoning = message.get("reasoning")
855
+ lines = reasoning.strip().split('\n')
856
+ for line in lines:
857
+ if line and not line.startswith('I should') and not line.startswith('Let me'):
858
+ return line.strip()
859
+ for line in lines:
860
+ if line.strip():
861
+ return line.strip()
862
+ return message.get("content", "")
863
+ elif "delta" in result["choices"][0]:
864
+ return result["choices"][0]["delta"].get("content", "")
865
+
866
+ elif provider == "OpenAI":
867
+ if hasattr(result, "choices") and len(result.choices) > 0:
868
+ return result.choices[0].message.content
869
+
870
+ elif provider == "HuggingFace":
871
+ return result.get("generated_text", "")
872
+
873
+ elif provider == "Groq":
874
+ if hasattr(result, "choices") and len(result.choices) > 0:
875
+ return result.choices[0].message.content
876
+
877
+ elif provider == "Cohere":
878
+ if hasattr(result, "text"):
879
+ return result.text
880
+
881
+ elif provider == "GLHF":
882
+ if hasattr(result, "choices") and len(result.choices) > 0:
883
+ return result.choices[0].message.content
884
+
885
+ logger.error(f"Unexpected response structure from {provider}: {result}")
886
+ return f"Error: Could not extract response from {provider} API result"
887
  except Exception as e:
888
  logger.error(f"Error extracting AI response: {str(e)}")
889
  return f"Error: {str(e)}"
890
 
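# Response shapes the extractor above handles, per provider:
#     OpenRouter       -> plain dict from response.json(): {"choices": [{"message": {"content": ...}}]}
#     OpenAI/Groq/GLHF -> SDK object exposing .choices[0].message.content
#     HuggingFace      -> {"generated_text": "..."}
#     Cohere           -> SDK object exposing .text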
891
+ # ==========================================================
892
+ # STREAMING HANDLERS
893
+ # ==========================================================
894
+
895
+ def openrouter_streaming_handler(response, chatbot, message_idx, message):
896
  try:
897
  # First add the user message if needed
898
  if len(chatbot) == message_idx:
899
+ chatbot.append([message, ""])
900
 
901
  for line in response.iter_lines():
902
  if not line:
 
915
  if "choices" in chunk and len(chunk["choices"]) > 0:
916
  delta = chunk["choices"][0].get("delta", {})
917
  if "content" in delta and delta["content"]:
918
+ # Update the current response
919
+ chatbot[-1][1] += delta["content"]
920
  yield chatbot
921
  except json.JSONDecodeError:
922
  logger.error(f"Failed to parse JSON from chunk: {data}")
 
924
  logger.error(f"Error in streaming handler: {str(e)}")
925
  # Add error message to the current response
926
  if len(chatbot) > message_idx:
927
+ chatbot[-1][1] += f"\n\nError during streaming: {str(e)}"
928
  yield chatbot
929
 
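# Example of one server-sent-events chunk the handler above parses (illustrative;
# the "data: " prefix is stripped in lines elided from this diff):
#
#     data: {"choices":[{"delta":{"content":"Hel"}}]}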
930
+ def openai_streaming_handler(response, chatbot, message_idx, message):
931
+ try:
932
+ # First add the user message if needed
933
+ if len(chatbot) == message_idx:
934
+ chatbot.append([message, ""])
935
+
936
+ full_response = ""
937
+ for chunk in response:
938
+ if hasattr(chunk.choices[0].delta, "content") and chunk.choices[0].delta.content is not None:
939
+ content = chunk.choices[0].delta.content
940
+ full_response += content
941
+ chatbot[-1][1] = full_response
942
+ yield chatbot
943
+
944
+ except Exception as e:
945
+ logger.error(f"Error in OpenAI streaming handler: {str(e)}")
946
+ # Add error message to the current response
947
+ chatbot[-1][1] += f"\n\nError during streaming: {str(e)}"
948
+ yield chatbot
949
+
950
+ def groq_streaming_handler(response, chatbot, message_idx, message):
951
+ try:
952
+ # First add the user message if needed
953
+ if len(chatbot) == message_idx:
954
+ chatbot.append([message, ""])
955
+
956
+ full_response = ""
957
+ for chunk in response:
958
+ if hasattr(chunk.choices[0].delta, "content") and chunk.choices[0].delta.content is not None:
959
+ content = chunk.choices[0].delta.content
960
+ full_response += content
961
+ chatbot[-1][1] = full_response
962
+ yield chatbot
963
+
964
+ except Exception as e:
965
+ logger.error(f"Error in Groq streaming handler: {str(e)}")
966
+ # Add error message to the current response
967
+ chatbot[-1][1] += f"\n\nError during streaming: {str(e)}"
968
+ yield chatbot
969
+
970
+ def glhf_streaming_handler(response, chatbot, message_idx, message):
971
+ try:
972
+ # First add the user message if needed
973
+ if len(chatbot) == message_idx:
974
+ chatbot.append([message, ""])
975
+
976
+ full_response = ""
977
+ for chunk in response:
978
+ if hasattr(chunk.choices[0].delta, "content") and chunk.choices[0].delta.content is not None:
979
+ content = chunk.choices[0].delta.content
980
+ full_response += content
981
+ chatbot[-1][1] = full_response
982
+ yield chatbot
983
+
984
+ except Exception as e:
985
+ logger.error(f"Error in GLHF streaming handler: {str(e)}")
986
+ # Add error message to the current response
987
+ chatbot[-1][1] += f"\n\nError during streaming: {str(e)}"
988
+ yield chatbot
989
+
990
+ # ==========================================================
991
+ # MAIN FUNCTION TO ASK AI
992
+ # ==========================================================
993
+
994
+ def ask_ai(message, history, provider, model_choice, temperature, max_tokens, top_p,
995
+ frequency_penalty, presence_penalty, repetition_penalty, top_k, min_p,
996
+ seed, top_a, stream_output, response_format, images, documents,
997
+ reasoning_effort, system_message, transforms, api_key_override=None):
998
+ """Enhanced AI query function with support for multiple providers"""
999
  # Validate input
1000
  if not message.strip() and not images and not documents:
1001
  return history
1002
 
 
1003
  # Copy history to new list to avoid modifying the original
1004
  chat_history = list(history)
1005
 
 
1019
  # Add current message
1020
  messages.append({"role": "user", "content": content})
1021
 
1022
+ # Common parameters for all providers
1023
+ common_params = {
1024
  "temperature": temperature,
1025
  "max_tokens": max_tokens,
1026
  "top_p": top_p,
 
1029
  "stream": stream_output
1030
  }
1031
 
1032
  try:
1033
+ # Process based on provider
1034
+ if provider == "OpenRouter":
1035
+ # Get model ID from registry
1036
+ model_id, _ = get_model_info(provider, model_choice)
1037
+ if not model_id:
1038
+ error_message = f"Error: Model '{model_choice}' not found in OpenRouter"
1039
+ chat_history.append([message, error_message])
1040
+ return chat_history
1041
+
1042
+ # Build OpenRouter payload
1043
+ payload = {
1044
+ "model": model_id,
1045
+ "messages": messages,
1046
+ **common_params
1047
+ }
1048
 
1049
+ # Add optional parameters if set
1050
+ if repetition_penalty != 1.0:
1051
+ payload["repetition_penalty"] = repetition_penalty
 
1052
 
1053
+ if top_k > 0:
1054
+ payload["top_k"] = top_k
 
 
 
 
1055
 
1056
+ if min_p > 0:
1057
+ payload["min_p"] = min_p
1058
 
1059
+ if seed > 0:
1060
+ payload["seed"] = seed
1061
+
1062
+ if top_a > 0:
1063
+ payload["top_a"] = top_a
1064
+
1065
+ # Add response format if JSON is requested
1066
+ if response_format == "json_object":
1067
+ payload["response_format"] = {"type": "json_object"}
1068
+
1069
+ # Add reasoning if selected
1070
+ if reasoning_effort != "none":
1071
+ payload["reasoning"] = {
1072
+ "effort": reasoning_effort
1073
+ }
1074
+
1075
+ # Add transforms if selected
1076
+ if transforms:
1077
+ payload["transforms"] = transforms
1078
+
1079
+ # Call OpenRouter API
1080
+ logger.info(f"Sending request to OpenRouter model: {model_id}")
1081
+
1082
+ response = call_openrouter_api(payload, api_key_override)
1083
+
1084
+ # Handle streaming response
1085
+ if stream_output and response.status_code == 200:
1086
+ # Add empty response slot to history
1087
+ chat_history.append([message, ""])
1088
+
1089
+ # Set up generator for streaming updates
1090
+ def streaming_generator():
1091
+ for updated_history in openrouter_streaming_handler(response, chat_history, len(chat_history) - 1, message):
1092
+ yield updated_history
1093
+
1094
+ return streaming_generator()
1095
+
1096
+ # Handle normal response
1097
+ elif response.status_code == 200:
1098
+ result = response.json()
1099
+ logger.info(f"Response content: {result}")
1100
+
1101
+ # Extract AI response
1102
+ ai_response = extract_ai_response(result, provider)
1103
+
1104
+ # Add response to history
1105
+ chat_history.append([message, ai_response])
1106
+ return chat_history
1107
+
1108
+ # Handle error response
1109
+ else:
1110
+ error_message = f"Error: Status code {response.status_code}"
1111
+ try:
1112
+ response_data = response.json()
1113
+ error_message += f"\n\nDetails: {json.dumps(response_data, indent=2)}"
1114
+ except:
1115
+ error_message += f"\n\nResponse: {response.text}"
1116
+
1117
+ logger.error(error_message)
1118
+ chat_history.append([message, error_message])
1119
+ return chat_history
1120
+
1121
+ elif provider == "OpenAI":
1122
+ # Get model ID from registry
1123
+ model_id, _ = get_model_info(provider, model_choice)
1124
+ if not model_id:
1125
+ error_message = f"Error: Model '{model_choice}' not found in OpenAI"
1126
+ chat_history.append([message, error_message])
1127
+ return chat_history
1128
+
1129
+ # Build OpenAI payload
1130
+ payload = {
1131
+ "model": model_id,
1132
+ "messages": messages,
1133
+ **common_params
1134
+ }
1135
+
1136
+ # Add response format if JSON is requested
1137
+ if response_format == "json_object":
1138
+ payload["response_format"] = {"type": "json_object"}
1139
+
1140
+ # Call OpenAI API
1141
+ logger.info(f"Sending request to OpenAI model: {model_id}")
1142
+
1143
+ try:
1144
+ response = call_openai_api(payload, api_key_override)
1145
+
1146
+ # Handle streaming response
1147
+ if stream_output:
1148
+ # Add empty response slot to history
1149
+ chat_history.append([message, ""])
1150
+
1151
+ # Set up generator for streaming updates
1152
+ def streaming_generator():
1153
+ for updated_history in openai_streaming_handler(response, chat_history, len(chat_history) - 1, message):
1154
+ yield updated_history
1155
+
1156
+ return streaming_generator()
1157
+
1158
+ # Handle normal response
1159
+ else:
1160
+ ai_response = extract_ai_response(response, provider)
1161
+ chat_history.append([message, ai_response])
1162
+ return chat_history
1163
+ except Exception as e:
1164
+ error_message = f"OpenAI API Error: {str(e)}"
1165
+ logger.error(error_message)
1166
+ chat_history.append([message, error_message])
1167
+ return chat_history
1168
+
1169
+ elif provider == "HuggingFace":
1170
+ # Get model ID from registry
1171
+ model_id, _ = get_model_info(provider, model_choice)
1172
+ if not model_id:
1173
+ error_message = f"Error: Model '{model_choice}' not found in HuggingFace"
1174
+ chat_history.append([message, error_message])
1175
+ return chat_history
1176
+
1177
+ # Build HuggingFace payload
1178
+ payload = {
1179
+ "model": model_id,
1180
+ "messages": messages,
1181
+ "temperature": temperature,
1182
+ "max_tokens": max_tokens
1183
+ }
1184
+
1185
+ # Call HuggingFace API
1186
+ logger.info(f"Sending request to HuggingFace model: {model_id}")
1187
 
 
1188
  try:
1189
+ response = call_huggingface_api(payload, api_key_override)
1190
+
1191
+ # Extract response
1192
+ ai_response = extract_ai_response(response, provider)
1193
+ chat_history.append([message, ai_response])
1194
+ return chat_history
1195
+ except Exception as e:
1196
+ error_message = f"HuggingFace API Error: {str(e)}"
1197
+ logger.error(error_message)
1198
+ chat_history.append([message, error_message])
1199
+ return chat_history
+
+    elif provider == "Groq":
+        # Get model ID from registry
+        model_id, _ = get_model_info(provider, model_choice)
+        if not model_id:
+            error_message = f"Error: Model '{model_choice}' not found in Groq"
+            chat_history.append([message, error_message])
+            return chat_history
+
+        # Build Groq payload
+        payload = {
+            "model": model_id,
+            "messages": messages,
+            "temperature": temperature,
+            "max_tokens": max_tokens,
+            "top_p": top_p,
+            "stream": stream_output
+        }
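+        # Groq exposes an OpenAI-compatible chat completions endpoint, which is
+        # why this payload mirrors the OpenAI schema (including "stream").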
+
+        # Call Groq API
+        logger.info(f"Sending request to Groq model: {model_id}")
+
+        try:
+            response = call_groq_api(payload, api_key_override)
+
+            # Handle streaming response
+            if stream_output:
+                # Add empty response slot to history
+                chat_history.append([message, ""])
+
+                # Set up generator for streaming updates
+                def streaming_generator():
+                    for updated_history in groq_streaming_handler(response, chat_history, len(chat_history) - 1, message):
+                        yield updated_history
+
+                return streaming_generator()
+
+            # Handle normal response
+            else:
+                ai_response = extract_ai_response(response, provider)
+                chat_history.append([message, ai_response])
+                return chat_history
+        except Exception as e:
+            error_message = f"Groq API Error: {str(e)}"
+            logger.error(error_message)
+            chat_history.append([message, error_message])
+            return chat_history
+
+    elif provider == "Cohere":
+        # Get model ID from registry
+        model_id, _ = get_model_info(provider, model_choice)
+        if not model_id:
+            error_message = f"Error: Model '{model_choice}' not found in Cohere"
+            chat_history.append([message, error_message])
+            return chat_history
+
+        # Build Cohere payload (streaming is not wired up for this provider,
+        # so responses are always returned in full)
+        payload = {
+            "model": model_id,
+            "messages": messages,
+            "temperature": temperature,
+            "max_tokens": max_tokens
+        }
+
+        # Call Cohere API
+        logger.info(f"Sending request to Cohere model: {model_id}")
+
+        try:
+            response = call_cohere_api(payload, api_key_override)
+
+            # Extract response
+            ai_response = extract_ai_response(response, provider)
+            chat_history.append([message, ai_response])
+            return chat_history
+        except Exception as e:
+            error_message = f"Cohere API Error: {str(e)}"
+            logger.error(error_message)
+            chat_history.append([message, error_message])
+            return chat_history
+
+    elif provider == "GLHF":
+        # Get model ID from registry
+        model_id, _ = get_model_info(provider, model_choice)
+        if not model_id:
+            error_message = f"Error: Model '{model_choice}' not found in GLHF"
+            chat_history.append([message, error_message])
+            return chat_history
+
+        # Build GLHF payload
+        payload = {
+            "model": model_id,  # The hf: prefix will be added in the API call
+            "messages": messages,
+            "temperature": temperature,
+            "max_tokens": max_tokens,
+            "stream": stream_output
+        }
+
+        # Call GLHF API
+        logger.info(f"Sending request to GLHF model: {model_id}")
+
+        try:
+            response = call_glhf_api(payload, api_key_override)
+
+            # Handle streaming response
+            if stream_output:
+                # Add empty response slot to history
+                chat_history.append([message, ""])
+
+                # Set up generator for streaming updates
+                def streaming_generator():
+                    for updated_history in glhf_streaming_handler(response, chat_history, len(chat_history) - 1, message):
+                        yield updated_history
+
+                return streaming_generator()
+
+            # Handle normal response
+            else:
+                ai_response = extract_ai_response(response, provider)
+                chat_history.append([message, ai_response])
+                return chat_history
+        except Exception as e:
+            error_message = f"GLHF API Error: {str(e)}"
+            logger.error(error_message)
+            chat_history.append([message, error_message])
+            return chat_history
+
+    else:
+        error_message = f"Error: Unsupported provider '{provider}'"
         chat_history.append([message, error_message])
         return chat_history

     """Reset all inputs"""
     return [], "", [], [], 0.7, 1000, 0.8, 0.0, 0.0, 1.0, 40, 0.1, 0, 0.0, False, "default", "none", "", []

+# ==========================================================
+# UI CREATION
+# ==========================================================
+
 def create_app():
+    """Create the Multi-Provider CrispChat Gradio application"""
     with gr.Blocks(
+        title="Multi-Provider CrispChat",
         css="""
         .context-size {
             font-size: 0.9em;

             font-size: 0.8em;
             margin-left: 5px;
         }
+        .provider-selection {
+            margin-bottom: 10px;
+            padding: 10px;
+            border-radius: 5px;
+            background-color: #f5f5f5;
+        }
         """
     ) as demo:
         gr.Markdown("""
+        # 🤖 Multi-Provider CrispChat

+        Chat with AI models from multiple providers: OpenRouter, OpenAI, HuggingFace, Groq, Cohere, and GLHF.
         """)

         with gr.Row():
             with gr.Column(scale=2):
+                # Chatbot interface
                 chatbot = gr.Chatbot(
                     height=500,
                     show_copy_button=True,
                     show_label=False,
                     avatar_images=(None, "https://upload.wikimedia.org/wikipedia/commons/0/04/ChatGPT_logo.svg"),
+                    type="tuples",  # the handlers build history as [user, assistant] pairs, not "messages" dicts
+                    elem_id="chat-window"
                 )

         with gr.Row():

                 placeholder="Type your message here...",
                 label="Message",
                 lines=2,
+                elem_id="message-input",
                 scale=4
             )

         )

             with gr.Column(scale=1):
+                with gr.Group(elem_classes="provider-selection"):
+                    gr.Markdown("### Provider Selection")
+
+                    # Provider selection
+                    provider_choice = gr.Radio(
+                        choices=["OpenRouter", "OpenAI", "HuggingFace", "Groq", "Cohere", "GLHF"],
+                        value="OpenRouter",
+                        label="AI Provider"
+                    )
+
+                    # API key input
+                    api_key_override = gr.Textbox(
+                        placeholder="Override API key (leave empty to use environment variable)",
+                        label="API Key Override",
+                        type="password"
+                    )
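+                    # Assumption: when non-empty, this value is forwarded to the
+                    # call_*_api helpers as api_key_override and takes precedence
+                    # over the provider's environment variable.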
+
                 with gr.Group():
                     gr.Markdown("### Model Selection")

                         show_label=False
                     )

+                    # Provider-specific model dropdowns
+                    openrouter_model = gr.Dropdown(
+                        choices=[model[0] for model in OPENROUTER_ALL_MODELS],
+                        value=OPENROUTER_ALL_MODELS[0][0] if OPENROUTER_ALL_MODELS else None,
+                        label="OpenRouter Model",
+                        elem_id="openrouter-model-choice",
+                        visible=True
+                    )
 
+                    openai_model = gr.Dropdown(
+                        choices=list(OPENAI_MODELS.keys()),
+                        value="gpt-3.5-turbo" if "gpt-3.5-turbo" in OPENAI_MODELS else None,
+                        label="OpenAI Model",
+                        elem_id="openai-model-choice",
+                        visible=False
+                    )
+
+                    hf_model = gr.Dropdown(
+                        choices=list(HUGGINGFACE_MODELS.keys()),
+                        value="mistralai/Mistral-7B-Instruct-v0.3" if "mistralai/Mistral-7B-Instruct-v0.3" in HUGGINGFACE_MODELS else None,
+                        label="HuggingFace Model",
+                        elem_id="hf-model-choice",
+                        visible=False
+                    )
+
+                    groq_model = gr.Dropdown(
+                        choices=list(GROQ_MODELS.keys()),
+                        value="llama-3.1-8b-instant" if "llama-3.1-8b-instant" in GROQ_MODELS else None,
+                        label="Groq Model",
+                        elem_id="groq-model-choice",
+                        visible=False
+                    )
+
+                    cohere_model = gr.Dropdown(
+                        choices=list(COHERE_MODELS.keys()),
+                        value="command-r-plus" if "command-r-plus" in COHERE_MODELS else None,
+                        label="Cohere Model",
+                        elem_id="cohere-model-choice",
+                        visible=False
+                    )
+
+                    glhf_model = gr.Dropdown(
+                        choices=list(GLHF_MODELS.keys()),
+                        value="mistralai/Mistral-7B-Instruct-v0.3" if "mistralai/Mistral-7B-Instruct-v0.3" in GLHF_MODELS else None,
+                        label="GLHF Model",
+                        elem_id="glhf-model-choice",
+                        visible=False
+                    )
+
+                    context_display = gr.Textbox(
+                        value=update_context_display("OpenRouter", OPENROUTER_ALL_MODELS[0][0]),
+                        label="Context Size",
+                        interactive=False,
+                        elem_classes="context-size"
+                    )

             with gr.Accordion("Generation Parameters", open=False):
                 with gr.Group(elem_classes="parameter-grid"):

                     reasoning_effort = gr.Radio(
                         ["none", "low", "medium", "high"],
                         value="none",
+                        label="Reasoning Effort (OpenRouter)"
                     )

             with gr.Accordion("Advanced Options", open=False):

                 gr.Markdown("""
                 * **json_object**: Forces the model to respond with valid JSON only.
+                * Only available on certain models; check each model's support before enabling it.
                 """)

                 # Custom instructing options
         # Add a model information section
         with gr.Accordion("About Selected Model", open=False):
             model_info_display = gr.HTML(
+                value=update_model_info("OpenRouter", OPENROUTER_ALL_MODELS[0][0])
             )

         # Add usage instructions

             gr.Markdown("""
             ## Basic Usage
             1. Type your message in the input box
+            2. Select a provider and model
             3. Click "Send" or press Enter

             ## Working with Files
             - **Images**: Upload images to use with vision-capable models
             - **Documents**: Upload PDF, Markdown, or text files to analyze their content

+            ## Provider Information
+            - **OpenRouter**: Free access to a wide range of models, with context windows up to 2M tokens
+            - **OpenAI**: Requires an API key; includes GPT-3.5 and GPT-4 models
+            - **HuggingFace**: Direct access to open models; some models require an API key
+            - **Groq**: High-performance inference; requires an API key
+            - **Cohere**: Specialized in language understanding; requires an API key
+            - **GLHF**: Access to HuggingFace models; requires an API key
+
             ## Advanced Parameters
             - **Temperature**: Controls randomness (higher = more creative, lower = more deterministic)
             - **Max Tokens**: Maximum length of the response
             - **Top P**: Nucleus sampling threshold (higher = consider more tokens)
+            - **Reasoning Effort**: Some models can expose their reasoning process (OpenRouter only)
             """)

         # Add a footer with version info
         footer_md = gr.Markdown("""
         ---
+        ### Multi-Provider CrispChat v1.0
+        Built with ❤️ using Gradio and multiple AI provider APIs | Context sizes shown next to model names
         """)

+        # Define event handlers
+        def toggle_model_dropdowns(provider):
+            """Show/hide model dropdowns based on provider selection"""
+            return {
+                openrouter_model: gr.update(visible=(provider == "OpenRouter")),
+                openai_model: gr.update(visible=(provider == "OpenAI")),
+                hf_model: gr.update(visible=(provider == "HuggingFace")),
+                groq_model: gr.update(visible=(provider == "Groq")),
+                cohere_model: gr.update(visible=(provider == "Cohere")),
+                glhf_model: gr.update(visible=(provider == "GLHF"))
+            }
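+        # Returning a dict keyed by components only works when those same
+        # components are also listed in the event's outputs; only the listed
+        # entries get updated.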
+
+        def update_context_for_provider(provider, openrouter_model, openai_model, hf_model, groq_model, cohere_model, glhf_model):
+            """Update context display based on selected provider and model"""
+            if provider == "OpenRouter":
+                return update_context_display(provider, openrouter_model)
+            elif provider == "OpenAI":
+                return update_context_display(provider, openai_model)
+            elif provider == "HuggingFace":
+                return update_context_display(provider, hf_model)
+            elif provider == "Groq":
+                return update_context_display(provider, groq_model)
+            elif provider == "Cohere":
+                return update_context_display(provider, cohere_model)
+            elif provider == "GLHF":
+                return update_context_display(provider, glhf_model)
+            return "Unknown"
+
+        def update_model_info_for_provider(provider, openrouter_model, openai_model, hf_model, groq_model, cohere_model, glhf_model):
+            """Update model info based on selected provider and model"""
+            if provider == "OpenRouter":
+                return update_model_info(provider, openrouter_model)
+            elif provider == "OpenAI":
+                return update_model_info(provider, openai_model)
+            elif provider == "HuggingFace":
+                return update_model_info(provider, hf_model)
+            elif provider == "Groq":
+                return update_model_info(provider, groq_model)
+            elif provider == "Cohere":
+                return update_model_info(provider, cohere_model)
+            elif provider == "GLHF":
+                return update_model_info(provider, glhf_model)
+            return "<p>Model information not available</p>"
+
+        def filter_provider_models(provider, search_term):
+            """Filter models for the selected provider"""
+            if provider == "OpenRouter":
+                all_models = [model[0] for model in OPENROUTER_ALL_MODELS]
+            elif provider == "OpenAI":
+                all_models = list(OPENAI_MODELS.keys())
+            elif provider == "HuggingFace":
+                all_models = list(HUGGINGFACE_MODELS.keys())
+            elif provider == "Groq":
+                all_models = list(GROQ_MODELS.keys())
+            elif provider == "Cohere":
+                all_models = list(COHERE_MODELS.keys())
+            elif provider == "GLHF":
+                all_models = list(GLHF_MODELS.keys())
+            else:
+                return [], None
+
+            if not search_term:
+                return all_models, all_models[0] if all_models else None
+
+            filtered_models = [model for model in all_models if search_term.lower() in model.lower()]
+
+            if filtered_models:
+                return filtered_models, filtered_models[0]
+            else:
+                return all_models, all_models[0] if all_models else None
+
+        def refresh_groq_models_list():
+            """Refresh the list of Groq models"""
+            global GROQ_MODELS
+            GROQ_MODELS = fetch_groq_models()
+            return gr.update(choices=list(GROQ_MODELS.keys()))
+
+        def get_current_model(provider, openrouter_model, openai_model, hf_model, groq_model, cohere_model, glhf_model):
+            """Get the currently selected model based on provider"""
+            if provider == "OpenRouter":
+                return openrouter_model
+            elif provider == "OpenAI":
+                return openai_model
+            elif provider == "HuggingFace":
+                return hf_model
+            elif provider == "Groq":
+                return groq_model
+            elif provider == "Cohere":
+                return cohere_model
+            elif provider == "GLHF":
+                return glhf_model
+            return None
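+        # Hypothetical refactor (sketch only, not wired up): keeping the dropdowns
+        # in a dict such as {"OpenRouter": openrouter_model, "OpenAI": openai_model, ...}
+        # would collapse the dispatch helpers above into single dict lookups.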

+        # Process uploaded images
+        image_upload_btn.upload(
+            fn=lambda files: files,
+            inputs=image_upload_btn,
+            outputs=images
+        )
+
+        # Set up provider selection event
+        provider_choice.change(
+            fn=toggle_model_dropdowns,
+            inputs=provider_choice,
+            outputs=[openrouter_model, openai_model, hf_model, groq_model, cohere_model, glhf_model]
+        ).then(
+            fn=update_context_for_provider,
+            inputs=[provider_choice, openrouter_model, openai_model, hf_model, groq_model, cohere_model, glhf_model],
+            outputs=context_display
+        ).then(
+            fn=update_model_info_for_provider,
+            inputs=[provider_choice, openrouter_model, openai_model, hf_model, groq_model, cohere_model, glhf_model],
+            outputs=model_info_display
+        )
+
+        # Set up model search event
+        # Fix: Gradio event outputs must be component references (gr.update() objects
+        # are return values, not output targets), so route the filtered choices to
+        # every provider dropdown and only modify the active one.
+        def apply_model_filter(provider, search_term):
+            choices, value = filter_provider_models(provider, search_term)
+            providers = ["OpenRouter", "OpenAI", "HuggingFace", "Groq", "Cohere", "GLHF"]
+            return [
+                gr.update(choices=choices, value=value) if provider == p else gr.update()
+                for p in providers
+            ]
+
         model_search.change(
+            fn=apply_model_filter,
+            inputs=[provider_choice, model_search],
+            outputs=[openrouter_model, openai_model, hf_model, groq_model, cohere_model, glhf_model]
+        )

+        # Set up model change events
+        openrouter_model.change(
+            fn=lambda model: update_context_display("OpenRouter", model),
+            inputs=openrouter_model,
             outputs=context_display
+        ).then(
+            fn=lambda model: update_model_info("OpenRouter", model),
+            inputs=openrouter_model,
+            outputs=model_info_display
         )

+        openai_model.change(
+            fn=lambda model: update_context_display("OpenAI", model),
+            inputs=openai_model,
+            outputs=context_display
+        ).then(
+            fn=lambda model: update_model_info("OpenAI", model),
+            inputs=openai_model,
             outputs=model_info_display
         )

+        hf_model.change(
+            fn=lambda model: update_context_display("HuggingFace", model),
+            inputs=hf_model,
+            outputs=context_display
+        ).then(
+            fn=lambda model: update_model_info("HuggingFace", model),
+            inputs=hf_model,
+            outputs=model_info_display
+        )

+        groq_model.change(
+            fn=lambda model: update_context_display("Groq", model),
+            inputs=groq_model,
+            outputs=context_display
+        ).then(
+            fn=lambda model: update_model_info("Groq", model),
+            inputs=groq_model,
+            outputs=model_info_display
         )
+
+        cohere_model.change(
+            fn=lambda model: update_context_display("Cohere", model),
+            inputs=cohere_model,
+            outputs=context_display
+        ).then(
+            fn=lambda model: update_model_info("Cohere", model),
+            inputs=cohere_model,
+            outputs=model_info_display
+        )
+
+        glhf_model.change(
+            fn=lambda model: update_context_display("GLHF", model),
+            inputs=glhf_model,
+            outputs=context_display
+        ).then(
+            fn=lambda model: update_model_info("GLHF", model),
+            inputs=glhf_model,
+            outputs=model_info_display
         )

+        # Set up submission event
+        def submit_message(message, history, provider, openrouter_model, openai_model, hf_model, groq_model, cohere_model, glhf_model,
+                           temperature, max_tokens, top_p, frequency_penalty, presence_penalty, repetition_penalty,
+                           top_k, min_p, seed, top_a, stream_output, response_format,
+                           images, documents, reasoning_effort, system_message, transforms, api_key_override):
+            """Submit message to selected provider and model"""
+            # Get the currently selected model
+            model_choice = get_current_model(provider, openrouter_model, openai_model, hf_model, groq_model, cohere_model, glhf_model)
+
+            # Check if model is selected
+            if not model_choice:
+                history.append([message, f"Error: No model selected for provider {provider}"])
+                return history
+
+            # Call the ask_ai function with the appropriate parameters
+            return ask_ai(
+                message=message,
+                history=history,
+                provider=provider,
+                model_choice=model_choice,
+                temperature=temperature,
+                max_tokens=max_tokens,
+                top_p=top_p,
+                frequency_penalty=frequency_penalty,
+                presence_penalty=presence_penalty,
+                repetition_penalty=repetition_penalty,
+                top_k=top_k,
+                min_p=min_p,
+                seed=seed,
+                top_a=top_a,
+                stream_output=stream_output,
+                response_format=response_format,
+                images=images,
+                documents=documents,
+                reasoning_effort=reasoning_effort,
+                system_message=system_message,
+                transforms=transforms,
+                api_key_override=api_key_override
+            )
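+        # Caveat: when stream_output is enabled, ask_ai returns a generator object.
+        # For Gradio to actually stream, the event handler itself must be a
+        # generator (e.g. `yield from ask_ai(...)`); returning the generator as a
+        # plain value may render it as a single opaque output instead.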
+
+        # Submit button click event
         submit_btn.click(
+            fn=submit_message,
             inputs=[
+                message, chatbot, provider_choice,
+                openrouter_model, openai_model, hf_model, groq_model, cohere_model, glhf_model,
+                temperature, max_tokens, top_p, frequency_penalty, presence_penalty, repetition_penalty,
                 top_k, min_p, seed, top_a, stream_output, response_format,
+                images, documents, reasoning_effort, system_message, transforms, api_key_override
             ],
             outputs=chatbot,
             show_progress="minimal",

             outputs=message
         )

+        # Also submit on Enter key
         message.submit(
+            fn=submit_message,
             inputs=[
+                message, chatbot, provider_choice,
+                openrouter_model, openai_model, hf_model, groq_model, cohere_model, glhf_model,
+                temperature, max_tokens, top_p, frequency_penalty, presence_penalty, repetition_penalty,
                 top_k, min_p, seed, top_a, stream_output, response_format,
+                images, documents, reasoning_effort, system_message, transforms, api_key_override
             ],
             outputs=chatbot,
             show_progress="minimal",

             outputs=message
         )

+        # Clear chat button
         clear_btn.click(
             fn=clear_chat,
             inputs=[],

             ]
         )

     return demo

 # Launch the app
 if __name__ == "__main__":
+    # Check API keys before starting
     if not OPENROUTER_API_KEY:
         logger.warning("WARNING: OPENROUTER_API_KEY environment variable is not set")
+        print("WARNING: OpenRouter API key not found. Set OPENROUTER_API_KEY environment variable to access free models.")

     demo = create_app()
     demo.launch(