cstr committed on
Commit
7d23974
·
verified ·
1 Parent(s): 0b1b904

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +102 -71
app.py CHANGED
@@ -639,8 +639,12 @@ def is_vision_model(provider, model_name):
639
  return True
640
 
641
  # Also check for common vision indicators in model names
642
- if any(x in model_name.lower() for x in ["vl", "vision", "visual", "llava", "gemini"]):
643
- return True
 
 
 
 
644
 
645
  return False
646
 
@@ -806,21 +810,28 @@ def call_groq_api(payload, api_key_override=None):
806
 
807
  # Extract parameters from payload
808
  model = payload.get("model", "llama-3.1-8b-instant")
809
- messages = payload.get("messages", [])
810
- temperature = payload.get("temperature", 0.7)
811
- max_tokens = payload.get("max_tokens", 1000)
812
- stream = payload.get("stream", False)
813
- top_p = payload.get("top_p", 0.9)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
814
 
815
  # Create completion
816
- response = client.chat.completions.create(
817
- model=model,
818
- messages=messages,
819
- temperature=temperature,
820
- max_tokens=max_tokens,
821
- stream=stream,
822
- top_p=top_p
823
- )
824
 
825
  return response
826
  except Exception as e:
@@ -837,7 +848,7 @@ def call_cohere_api(payload, api_key_override=None):
837
  if not api_key:
838
  raise ValueError("Cohere API key is required")
839
 
840
- client = cohere.Client(api_key=api_key)
841
 
842
  # Extract parameters from payload
843
  model = payload.get("model", "command-r-plus")
@@ -845,40 +856,27 @@ def call_cohere_api(payload, api_key_override=None):
845
  temperature = payload.get("temperature", 0.7)
846
  max_tokens = payload.get("max_tokens", 1000)
847
 
848
- # Format messages for Cohere
849
- chat_history = []
850
- user_message = ""
851
-
852
  for msg in messages:
853
- if msg["role"] == "system":
854
- # For system message, we'll prepend to the user's first message
855
- system_content = msg["content"]
856
- if isinstance(system_content, list): # Handle multimodal content
857
- system_parts = []
858
- for item in system_content:
859
- if item["type"] == "text":
860
- system_parts.append(item["text"])
861
- system_content = "\n".join(system_parts)
862
- user_message = f"System: {system_content}\n\n" + user_message
863
- elif msg["role"] == "user":
864
- content = msg["content"]
865
- # Handle multimodal content
866
- if isinstance(content, list):
867
- text_parts = []
868
- for item in content:
869
- if item["type"] == "text":
870
- text_parts.append(item["text"])
871
- content = "\n".join(text_parts)
872
- user_message = content
873
- elif msg["role"] == "assistant":
874
- content = msg["content"]
875
- if content:
876
- chat_history.append({"role": "ASSISTANT", "message": content})
877
 
878
  # Create chat completion
879
  response = client.chat(
880
- message=user_message,
881
- chat_history=chat_history,
882
  model=model,
883
  temperature=temperature,
884
  max_tokens=max_tokens
@@ -898,7 +896,8 @@ def call_together_api(payload, api_key_override=None):
898
  api_key = api_key_override if api_key_override else TOGETHER_API_KEY
899
  if not api_key:
900
  raise ValueError("Together API key is required")
901
-
 
902
  client = openai.OpenAI(
903
  api_key=api_key,
904
  base_url="https://api.together.xyz/v1"
@@ -906,19 +905,34 @@ def call_together_api(payload, api_key_override=None):
906
 
907
  # Extract parameters from payload
908
  model = payload.get("model", "meta-llama/Llama-3.1-8B-Instruct")
909
- messages = payload.get("messages", [])
910
- temperature = payload.get("temperature", 0.7)
911
- max_tokens = payload.get("max_tokens", 1000)
912
- stream = payload.get("stream", False)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
913
 
914
  # Create completion
915
- response = client.chat.completions.create(
916
- model=model,
917
- messages=messages,
918
- temperature=temperature,
919
- max_tokens=max_tokens,
920
- stream=stream
921
- )
922
 
923
  return response
924
  except Exception as e:
@@ -928,7 +942,7 @@ def call_together_api(payload, api_key_override=None):
928
  def call_ovh_api(payload, api_key_override=None):
929
  """Make a call to OVH AI Endpoints API with error handling"""
930
  try:
931
- # Use custom OpenAI client with the OVH endpoint
932
  model = payload.get("model", "ovh/llama-3.1-8b-instruct")
933
  messages = payload.get("messages", [])
934
  temperature = payload.get("temperature", 0.7)
@@ -938,15 +952,25 @@ def call_ovh_api(payload, api_key_override=None):
938
  "Content-Type": "application/json"
939
  }
940
 
 
 
 
 
 
 
 
 
 
941
  data = {
942
  "model": model,
943
- "messages": messages,
944
  "temperature": temperature,
945
  "max_tokens": max_tokens
946
  }
947
 
 
948
  response = requests.post(
949
- "https://endpoints.ai.cloud.ovh.net/v1/chat/completions",
950
  headers=headers,
951
  json=data
952
  )
@@ -962,21 +986,28 @@ def call_ovh_api(payload, api_key_override=None):
962
  def call_cerebras_api(payload, api_key_override=None):
963
  """Make a call to Cerebras API with error handling"""
964
  try:
965
- # Use vanilla requests for this API
966
  model = payload.get("model", "cerebras/llama-3.1-8b")
967
- messages = payload.get("messages", [])
968
- temperature = payload.get("temperature", 0.7)
969
- max_tokens = payload.get("max_tokens", 1000)
970
-
971
- headers = {
972
- "Content-Type": "application/json"
973
- }
974
 
 
 
 
 
 
 
 
 
 
975
  data = {
976
  "model": model,
977
  "messages": messages,
978
- "temperature": temperature,
979
- "max_tokens": max_tokens
 
 
 
 
 
980
  }
981
 
982
  response = requests.post(
 
639
  return True
640
 
641
  # Also check for common vision indicators in model names
642
+ try:
643
+ if any(x in model_name.lower() for x in ["vl", "vision", "visual", "llava", "gemini"]):
644
+ return True
645
+ except AttributeError:
646
+ # In case model_name is not a string or has no lower method
647
+ return False
648
 
649
  return False
650
 
 
810
 
811
  # Extract parameters from payload
812
  model = payload.get("model", "llama-3.1-8b-instant")
813
+
814
+ # Clean up messages - remove any unexpected properties
815
+ messages = []
816
+ for msg in payload.get("messages", []):
817
+ clean_msg = {
818
+ "role": msg["role"],
819
+ "content": msg["content"]
820
+ }
821
+ messages.append(clean_msg)
822
+
823
+ # Basic parameters
824
+ groq_payload = {
825
+ "model": model,
826
+ "messages": messages,
827
+ "temperature": payload.get("temperature", 0.7),
828
+ "max_tokens": payload.get("max_tokens", 1000),
829
+ "stream": payload.get("stream", False),
830
+ "top_p": payload.get("top_p", 0.9)
831
+ }
832
 
833
  # Create completion
834
+ response = client.chat.completions.create(**groq_payload)
 
 
 
 
 
 
 
835
 
836
  return response
837
  except Exception as e:
 
848
  if not api_key:
849
  raise ValueError("Cohere API key is required")
850
 
851
+ client = cohere.ClientV2(api_key=api_key)
852
 
853
  # Extract parameters from payload
854
  model = payload.get("model", "command-r-plus")
 
856
  temperature = payload.get("temperature", 0.7)
857
  max_tokens = payload.get("max_tokens", 1000)
858
 
859
+ # Transform messages to Cohere format - IMPORTANT
860
+ # Cohere uses specific role names: USER, ASSISTANT, SYSTEM, TOOL
861
+ cohere_messages = []
 
862
  for msg in messages:
863
+ role = msg["role"].upper() # Cohere requires uppercase roles
864
+ content = msg["content"]
865
+
866
+ # Handle multimodal content
867
+ if isinstance(content, list):
868
+ text_parts = []
869
+ for item in content:
870
+ if item["type"] == "text":
871
+ text_parts.append(item["text"])
872
+ content = "\n".join(text_parts)
873
+
874
+ cohere_messages.append({"role": role, "content": content})
 
 
 
 
 
 
 
 
 
 
 
 
875
 
876
  # Create chat completion
877
  response = client.chat(
878
+ message=cohere_messages[-1]["content"] if cohere_messages else "",
879
+ chat_history=cohere_messages[:-1] if len(cohere_messages) > 1 else [],
880
  model=model,
881
  temperature=temperature,
882
  max_tokens=max_tokens
 
896
  api_key = api_key_override if api_key_override else TOGETHER_API_KEY
897
  if not api_key:
898
  raise ValueError("Together API key is required")
899
+
900
+ # Create client with Together base URL
901
  client = openai.OpenAI(
902
  api_key=api_key,
903
  base_url="https://api.together.xyz/v1"
 
905
 
906
  # Extract parameters from payload
907
  model = payload.get("model", "meta-llama/Llama-3.1-8B-Instruct")
908
+
909
+ # Fix model name format - Together API expects this format
910
+ if not model.startswith("meta-llama/") and "llama" in model.lower():
911
+ # Convert model ID format from "llama-3.1-8b-instruct" to "meta-llama/Llama-3.1-8B-Instruct"
912
+ parts = model.split("-")
913
+ formatted_name = "meta-llama/L" + "".join([p.capitalize() for p in parts])
914
+ model = formatted_name
915
+
916
+ # Clean up messages - remove any unexpected properties
917
+ messages = []
918
+ for msg in payload.get("messages", []):
919
+ clean_msg = {
920
+ "role": msg["role"],
921
+ "content": msg["content"]
922
+ }
923
+ messages.append(clean_msg)
924
+
925
+ # Create payload
926
+ together_payload = {
927
+ "model": model,
928
+ "messages": messages,
929
+ "temperature": payload.get("temperature", 0.7),
930
+ "max_tokens": payload.get("max_tokens", 1000),
931
+ "stream": payload.get("stream", False)
932
+ }
933
 
934
  # Create completion
935
+ response = client.chat.completions.create(**together_payload)
 
 
 
 
 
 
936
 
937
  return response
938
  except Exception as e:
 
942
  def call_ovh_api(payload, api_key_override=None):
943
  """Make a call to OVH AI Endpoints API with error handling"""
944
  try:
945
+ # Extract parameters from payload
946
  model = payload.get("model", "ovh/llama-3.1-8b-instruct")
947
  messages = payload.get("messages", [])
948
  temperature = payload.get("temperature", 0.7)
 
952
  "Content-Type": "application/json"
953
  }
954
 
955
+ # Clean up messages - remove any unexpected properties
956
+ clean_messages = []
957
+ for msg in messages:
958
+ clean_msg = {
959
+ "role": msg["role"],
960
+ "content": msg["content"]
961
+ }
962
+ clean_messages.append(clean_msg)
963
+
964
  data = {
965
  "model": model,
966
+ "messages": clean_messages,
967
  "temperature": temperature,
968
  "max_tokens": max_tokens
969
  }
970
 
971
+ # Updated endpoint with correct path
972
  response = requests.post(
973
+ "https://api.ai.cloud.ovh.net/v1/chat/completions",
974
  headers=headers,
975
  json=data
976
  )
 
986
  def call_cerebras_api(payload, api_key_override=None):
987
  """Make a call to Cerebras API with error handling"""
988
  try:
989
+ # Extract parameters from payload
990
  model = payload.get("model", "cerebras/llama-3.1-8b")
 
 
 
 
 
 
 
991
 
992
+ # Clean up messages - remove any unexpected properties
993
+ messages = []
994
+ for msg in payload.get("messages", []):
995
+ clean_msg = {
996
+ "role": msg["role"],
997
+ "content": msg["content"]
998
+ }
999
+ messages.append(clean_msg)
1000
+
1001
  data = {
1002
  "model": model,
1003
  "messages": messages,
1004
+ "temperature": payload.get("temperature", 0.7),
1005
+ "max_tokens": payload.get("max_tokens", 1000)
1006
+ }
1007
+
1008
+ headers = {
1009
+ "Content-Type": "application/json",
1010
+ "Authorization": f"Bearer {api_key_override or os.environ.get('CEREBRAS_API_KEY', '')}"
1011
  }
1012
 
1013
  response = requests.post(