Update app.py
app.py CHANGED
@@ -639,8 +639,12 @@ def is_vision_model(provider, model_name):
         return True
 
     # Also check for common vision indicators in model names
-    if any(x in model_name.lower() for x in ["vl", "vision", "visual", "llava", "gemini"]):
-        return True
+    try:
+        if any(x in model_name.lower() for x in ["vl", "vision", "visual", "llava", "gemini"]):
+            return True
+    except AttributeError:
+        # In case model_name is not a string or has no lower method
+        return False
 
     return False
 
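The try/except above guards against a non-string model_name; an explicit isinstance check is an equivalent and arguably more direct guard. A minimal standalone sketch (not part of the commit):

    def has_vision_hint(model_name) -> bool:
        # Same vision-indicator check as the hunk above, with an
        # explicit type guard instead of catching AttributeError.
        if not isinstance(model_name, str):
            return False
        return any(x in model_name.lower() for x in ["vl", "vision", "visual", "llava", "gemini"])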
@@ -806,21 +810,28 @@ def call_groq_api(payload, api_key_override=None):
 
         # Extract parameters from payload
         model = payload.get("model", "llama-3.1-8b-instant")
-        messages = payload.get("messages", [])
-        temperature = payload.get("temperature", 0.7)
-        max_tokens = payload.get("max_tokens", 1000)
-        stream = payload.get("stream", False)
-        top_p = payload.get("top_p", 0.9)
+
+        # Clean up messages - remove any unexpected properties
+        messages = []
+        for msg in payload.get("messages", []):
+            clean_msg = {
+                "role": msg["role"],
+                "content": msg["content"]
+            }
+            messages.append(clean_msg)
+
+        # Basic parameters
+        groq_payload = {
+            "model": model,
+            "messages": messages,
+            "temperature": payload.get("temperature", 0.7),
+            "max_tokens": payload.get("max_tokens", 1000),
+            "stream": payload.get("stream", False),
+            "top_p": payload.get("top_p", 0.9)
+        }
 
         # Create completion
-        response = client.chat.completions.create(
-            model=model,
-            messages=messages,
-            temperature=temperature,
-            max_tokens=max_tokens,
-            stream=stream,
-            top_p=top_p
-        )
+        response = client.chat.completions.create(**groq_payload)
 
         return response
     except Exception as e:
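For reference, a hedged usage sketch of the patched Groq path. It assumes client is a groq.Groq instance created earlier in the function (its construction is outside this hunk) and that non-streaming responses follow the OpenAI-style shape the Groq SDK mirrors:

    # Hypothetical payload; call_groq_api is the function patched above.
    payload = {
        "model": "llama-3.1-8b-instant",
        "messages": [{"role": "user", "content": "Say hi in one word."}],
        "temperature": 0.2,
        "max_tokens": 16,
    }
    response = call_groq_api(payload)
    print(response.choices[0].message.content)  # assumes stream=False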
@@ -837,7 +848,7 @@ def call_cohere_api(payload, api_key_override=None):
         if not api_key:
             raise ValueError("Cohere API key is required")
 
-        client = cohere.Client(api_key=api_key)
+        client = cohere.ClientV2(api_key=api_key)
 
         # Extract parameters from payload
         model = payload.get("model", "command-r-plus")
@@ -845,40 +856,27 @@
         temperature = payload.get("temperature", 0.7)
         max_tokens = payload.get("max_tokens", 1000)
 
-        #
-
-
-
+        # Transform messages to Cohere format - IMPORTANT
+        # Cohere uses specific role names: USER, ASSISTANT, SYSTEM, TOOL
+        cohere_messages = []
         for msg in messages:
-
-
-
-
-
-
-
-
-
-
-
-
-                # Handle multimodal content
-                if isinstance(content, list):
-                    text_parts = []
-                    for item in content:
-                        if item["type"] == "text":
-                            text_parts.append(item["text"])
-                    content = "\n".join(text_parts)
-                user_message = content
-            elif msg["role"] == "assistant":
-                content = msg["content"]
-                if content:
-                    chat_history.append({"role": "ASSISTANT", "message": content})
+            role = msg["role"].upper()  # Cohere requires uppercase roles
+            content = msg["content"]
+
+            # Handle multimodal content
+            if isinstance(content, list):
+                text_parts = []
+                for item in content:
+                    if item["type"] == "text":
+                        text_parts.append(item["text"])
+                content = "\n".join(text_parts)
+
+            cohere_messages.append({"role": role, "content": content})
 
         # Create chat completion
         response = client.chat(
-            message=user_message,
-            chat_history=chat_history,
+            message=cohere_messages[-1]["content"] if cohere_messages else "",
+            chat_history=cohere_messages[:-1] if len(cohere_messages) > 1 else [],
             model=model,
             temperature=temperature,
             max_tokens=max_tokens
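One caveat worth noting: cohere.ClientV2 exposes the v2 chat API, which takes a messages list (with lowercase roles) rather than the v1 message/chat_history pair used in the call above. If the v2 client is kept, the call would look roughly like this sketch, assuming cohere>=5; the response-shape line is likewise a v2 assumption:

    # v2-style call built from the cohere_messages list assembled above.
    response = client.chat(
        model=model,
        messages=[{"role": m["role"].lower(), "content": m["content"]}
                  for m in cohere_messages],
        temperature=temperature,
        max_tokens=max_tokens,
    )
    text = response.message.content[0].text  # v2 response shape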
@@ -898,7 +896,8 @@ def call_together_api(payload, api_key_override=None):
         api_key = api_key_override if api_key_override else TOGETHER_API_KEY
         if not api_key:
             raise ValueError("Together API key is required")
-
+
+        # Create client with Together base URL
         client = openai.OpenAI(
             api_key=api_key,
             base_url="https://api.together.xyz/v1"
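Together's endpoint is OpenAI-compatible, which is why a stock openai client pointed at their base URL works. A self-contained sketch of the same setup; the TOGETHER_API_KEY environment variable is an assumption, and the model ID is taken from the hunk's default rather than verified against Together's current catalog:

    import os
    import openai

    # Same client construction as the hunk above, outside the app.
    client = openai.OpenAI(
        api_key=os.environ["TOGETHER_API_KEY"],
        base_url="https://api.together.xyz/v1",
    )
    resp = client.chat.completions.create(
        model="meta-llama/Llama-3.1-8B-Instruct",
        messages=[{"role": "user", "content": "Hello"}],
    )
    print(resp.choices[0].message.content)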
@@ -906,19 +905,34 @@
 
         # Extract parameters from payload
         model = payload.get("model", "meta-llama/Llama-3.1-8B-Instruct")
-        messages = payload.get("messages", [])
-        temperature = payload.get("temperature", 0.7)
-        max_tokens = payload.get("max_tokens", 1000)
-        stream = payload.get("stream", False)
+
+        # Fix model name format - Together API expects this format
+        if not model.startswith("meta-llama/") and "llama" in model.lower():
+            # Convert model ID format from "llama-3.1-8b-instruct" to "meta-llama/Llama-3.1-8B-Instruct"
+            parts = model.split("-")
+            formatted_name = "meta-llama/L" + "".join([p.capitalize() for p in parts])
+            model = formatted_name
+
+        # Clean up messages - remove any unexpected properties
+        messages = []
+        for msg in payload.get("messages", []):
+            clean_msg = {
+                "role": msg["role"],
+                "content": msg["content"]
+            }
+            messages.append(clean_msg)
+
+        # Create payload
+        together_payload = {
+            "model": model,
+            "messages": messages,
+            "temperature": payload.get("temperature", 0.7),
+            "max_tokens": payload.get("max_tokens", 1000),
+            "stream": payload.get("stream", False)
+        }
 
         # Create completion
-        response = client.chat.completions.create(
-            model=model,
-            messages=messages,
-            temperature=temperature,
-            max_tokens=max_tokens,
-            stream=stream
-        )
+        response = client.chat.completions.create(**together_payload)
 
         return response
     except Exception as e:
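As written, the name formatter joins the capitalized parts with no separator and prefixes a stray "L", so "llama-3.1-8b-instruct" becomes "meta-llama/LLlama3.18bInstruct" rather than the "meta-llama/Llama-3.1-8B-Instruct" the comment intends. A sketch of the intended mapping; the helper name is hypothetical:

    def to_together_name(model: str) -> str:
        # Map "llama-3.1-8b-instruct" -> "meta-llama/Llama-3.1-8B-Instruct".
        parts = []
        for p in model.split("-"):
            # Size tokens like "8b" become "8B"; everything else is capitalized.
            if p.endswith("b") and p[:-1].replace(".", "").isdigit():
                parts.append(p.upper())
            else:
                parts.append(p.capitalize())
        return "meta-llama/" + "-".join(parts)

    assert to_together_name("llama-3.1-8b-instruct") == "meta-llama/Llama-3.1-8B-Instruct"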
@@ -928,7 +942,7 @@
 def call_ovh_api(payload, api_key_override=None):
     """Make a call to OVH AI Endpoints API with error handling"""
     try:
-        #
+        # Extract parameters from payload
         model = payload.get("model", "ovh/llama-3.1-8b-instruct")
         messages = payload.get("messages", [])
         temperature = payload.get("temperature", 0.7)
@@ -938,15 +952,25 @@
             "Content-Type": "application/json"
         }
 
+        # Clean up messages - remove any unexpected properties
+        clean_messages = []
+        for msg in messages:
+            clean_msg = {
+                "role": msg["role"],
+                "content": msg["content"]
+            }
+            clean_messages.append(clean_msg)
+
         data = {
             "model": model,
-            "messages": messages,
+            "messages": clean_messages,
             "temperature": temperature,
             "max_tokens": max_tokens
         }
 
+        # Updated endpoint with correct path
         response = requests.post(
-            "https://
+            "https://api.ai.cloud.ovh.net/v1/chat/completions",
             headers=headers,
             json=data
         )
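A hedged sketch of consuming the response from the requests.post call above, assuming the OVH endpoint returns an OpenAI-style chat-completions body:

    # Assumes an OpenAI-compatible response shape from the endpoint above.
    if response.status_code == 200:
        reply = response.json()["choices"][0]["message"]["content"]
    else:
        raise RuntimeError(f"OVH API error {response.status_code}: {response.text}")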
@@ -962,21 +986,28 @@
 def call_cerebras_api(payload, api_key_override=None):
     """Make a call to Cerebras API with error handling"""
     try:
-        #
+        # Extract parameters from payload
         model = payload.get("model", "cerebras/llama-3.1-8b")
-        messages = payload.get("messages", [])
-        temperature = payload.get("temperature", 0.7)
-        max_tokens = payload.get("max_tokens", 1000)
-
-        headers = {
-            "Content-Type": "application/json"
-        }
 
+        # Clean up messages - remove any unexpected properties
+        messages = []
+        for msg in payload.get("messages", []):
+            clean_msg = {
+                "role": msg["role"],
+                "content": msg["content"]
+            }
+            messages.append(clean_msg)
+
         data = {
             "model": model,
             "messages": messages,
-            "temperature": temperature,
-            "max_tokens": max_tokens
+            "temperature": payload.get("temperature", 0.7),
+            "max_tokens": payload.get("max_tokens", 1000)
+        }
+
+        headers = {
+            "Content-Type": "application/json",
+            "Authorization": f"Bearer {api_key_override or os.environ.get('CEREBRAS_API_KEY', '')}"
         }
 
         response = requests.post(
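And a hypothetical end-to-end use of the patched Cerebras path. It assumes CEREBRAS_API_KEY is exported in the environment and that call_cerebras_api returns the raw requests.Response with an OpenAI-style body; the hunk cuts off before the return, so both are assumptions:

    payload = {
        "model": "cerebras/llama-3.1-8b",
        "messages": [{"role": "user", "content": "Hello"}],
    }
    response = call_cerebras_api(payload)
    # Assumption: the function returns the requests.Response unmodified.
    print(response.json()["choices"][0]["message"]["content"])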