Eddie Pick committed
Commit · bb1d601 · Parent: 9438062

Updates

Files changed:
- .gitignore +1 -0
- models.py +107 -63
.gitignore CHANGED
@@ -121,6 +121,7 @@ celerybeat.pid
 
 # Environments
 .env
+.env.local
 .venv
 env/
 venv/
models.py CHANGED
@@ -1,11 +1,5 @@
 import os
-import
-from langchain.schema import SystemMessage, HumanMessage
-from langchain.prompts.chat import (
-    HumanMessagePromptTemplate,
-    SystemMessagePromptTemplate,
-    ChatPromptTemplate
-)
+from typing import Tuple, Optional
 from langchain.prompts.prompt import PromptTemplate
 from langchain.retrievers.multi_query import MultiQueryRetriever
 
@@ -17,83 +11,122 @@ from langchain_fireworks.embeddings import FireworksEmbeddings
 from langchain_groq.chat_models import ChatGroq
 from langchain_openai import ChatOpenAI
 from langchain_openai.embeddings import OpenAIEmbeddings
+from langchain_anthropic.chat_models import ChatAnthropic
+from langchain_mistralai.chat_models import ChatMistralAI
+from langchain_mistralai.embeddings import MistralAIEmbeddings
 from langchain_ollama.chat_models import ChatOllama
 from langchain_ollama.embeddings import OllamaEmbeddings
 from langchain_cohere.embeddings import CohereEmbeddings
 from langchain_cohere.chat_models import ChatCohere
+from langchain_huggingface import ChatHuggingFace, HuggingFaceEndpoint
+from langchain_community.embeddings import HuggingFaceInferenceAPIEmbeddings
 from langchain_openai.embeddings import OpenAIEmbeddings
 from langchain_google_genai import ChatGoogleGenerativeAI
 from langchain_google_genai.embeddings import GoogleGenerativeAIEmbeddings
 from langchain_community.chat_models import ChatPerplexity
 from langchain_together import ChatTogether
 from langchain_together.embeddings import TogetherEmbeddings
+from langchain.chat_models.base import BaseChatModel
+from langchain.embeddings.base import Embeddings
 
-def split_provider_model(provider_model):
-    parts = provider_model.split(
+def split_provider_model(provider_model: str) -> Tuple[str, Optional[str]]:
+    parts = provider_model.split(":", 1)
     provider = parts[0]
-    (deleted line truncated in source)
+    if len(parts) > 1:
+        model = parts[1] if parts[1] else None
+    else:
+        model = None
     return provider, model
 
-def get_model(provider_model, temperature=0.
+def get_model(provider_model: str, temperature: float = 0.7) -> BaseChatModel:
+    """
+    Get a model from a provider and model name.
+    returns BaseChatModel
+    """
     provider, model = split_provider_model(provider_model)
-    (old dispatch body; deleted lines truncated in source)
+    try:
+        match provider.lower():
+            case 'anthropic':
+                if model is None:
+                    model = "claude-3-sonnet-20240229"
+                chat_llm = ChatAnthropic(model=model, temperature=temperature)
+            case 'bedrock':
+                if model is None:
+                    model = "us.anthropic.claude-3-5-haiku-20241022-v1:0"
+                chat_llm = ChatBedrockConverse(model=model, temperature=temperature)
+            case 'cohere':
+                if model is None:
+                    model = 'command-r-plus'
+                chat_llm = ChatCohere(model=model, temperature=temperature)
+            case 'fireworks':
+                if model is None:
+                    model = 'accounts/fireworks/models/llama-v3p1-8b-instruct'
+                chat_llm = ChatFireworks(model_name=model, temperature=temperature, max_tokens=120000)
+            case 'googlegenerativeai':
+                if model is None:
+                    model = "gemini-1.5-flash"
+                chat_llm = ChatGoogleGenerativeAI(model=model, temperature=temperature,
+                                                  max_tokens=None, timeout=None, max_retries=2,)
+            case 'groq':
+                if model is None:
+                    model = 'llama-3.1-8b-instant'
+                chat_llm = ChatGroq(model_name=model, temperature=temperature)
+            case 'huggingface' | 'hf':
+                if model is None:
+                    model = 'mistralai/Mistral-Nemo-Instruct-2407'
+                llm = HuggingFaceEndpoint(
+                    repo_id=model,
+                    max_length=8192,
+                    temperature=temperature,
+                    huggingfacehub_api_token=os.getenv("HUGGINGFACE_API_KEY"),
+                )
+                chat_llm = ChatHuggingFace(llm=llm)
+            case 'ollama':
+                if model is None:
+                    model = 'llama3.1'
+                chat_llm = ChatOllama(model=model, temperature=temperature)
+            case 'openai':
+                if model is None:
+                    model = "gpt-4o-mini"
+                chat_llm = ChatOpenAI(model=model, temperature=temperature)
+            case 'openrouter':
+                if model is None:
+                    model = "google/gemini-flash-1.5-exp"
+                chat_llm = ChatOpenAI(model=model, temperature=temperature, base_url="https://openrouter.ai/api/v1", api_key=os.getenv("OPENROUTER_API_KEY"))
+            case 'mistralai' | 'mistral':
+                if model is None:
+                    model = "open-mistral-nemo"
+                chat_llm = ChatMistralAI(model=model, temperature=temperature)
+            case 'perplexity':
+                if model is None:
+                    model = 'llama-3.1-sonar-small-128k-online'
+                chat_llm = ChatPerplexity(model=model, temperature=temperature)
+            case 'together':
+                if model is None:
+                    model = 'meta-llama/Meta-Llama-3.1-8B-Instruct-Turbo'
+                chat_llm = ChatTogether(model=model, temperature=temperature)
+            case 'xai':
+                if model is None:
+                    model = 'grok-beta'
+                chat_llm = ChatOpenAI(model=model, api_key=os.getenv("XAI_API_KEY"), base_url="https://api.x.ai/v1", temperature=temperature)
+            case _:
+                raise ValueError(f"Unknown LLM provider {provider}")
+    except Exception as e:
+        raise ValueError(f"Unexpected error with {provider}: {str(e)}")
 
     return chat_llm
 
 
-def get_embedding_model(provider_model):
+def get_embedding_model(provider_model: str) -> Embeddings:
     provider, model = split_provider_model(provider_model)
-    match provider:
+    match provider.lower():
         case 'bedrock':
             if model is None:
                 model = "amazon.titan-embed-text-v2:0"
             embedding_model = BedrockEmbeddings(model_id=model)
         case 'cohere':
             if model is None:
-                model = "embed-
+                model = "embed-multilingual-v3"
             embedding_model = CohereEmbeddings(model=model)
         case 'fireworks':
             if model is None:
@@ -113,6 +146,14 @@ def get_embedding_model(provider_model):
             embedding_model = GoogleGenerativeAIEmbeddings(model=model)
         case 'groq':
             embedding_model = OpenAIEmbeddings(model="text-embedding-3-small")
+        case 'huggingface' | 'hf':
+            if model is None:
+                model = 'sentence-transformers/all-MiniLM-L6-v2'
+            embedding_model = HuggingFaceInferenceAPIEmbeddings(model_name=model, api_key=os.getenv("HUGGINGFACE_API_KEY"))
+        case 'mistral':
+            if model is None:
+                model = "mistral-embed"
+            embedding_model = MistralAIEmbeddings(model=model)
         case 'perplexity':
             raise ValueError(f"Cannot use Perplexity for embedding model")
         case 'together':
@@ -193,12 +234,15 @@ from models import get_model  # Make sure this import is correct
 class TestGetModel(unittest.TestCase):
 
     @patch('models.ChatBedrockConverse')
-    def
+    def test_bedrock_model_no_specific_model(self, mock_bedrock):
         result = get_model('bedrock')
-        mock_bedrock.assert_called_once_with(
-        (remaining deleted lines truncated in source)
+        mock_bedrock.assert_called_once_with(model=None, temperature=0.0)
+        self.assertEqual(result, mock_bedrock.return_value)
+
+    @patch('models.ChatBedrockConverse')
+    def test_bedrock_model_with_specific_model(self, mock_bedrock):
+        result = get_model('bedrock:specific-model')
+        mock_bedrock.assert_called_once_with(model='specific-model', temperature=0.0)
         self.assertEqual(result, mock_bedrock.return_value)
 
     @patch('models.ChatCohere')
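
Note on usage: the commit standardizes on a "provider[:model]" selector string for both chat and embedding backends. The text before the first colon picks the provider, the optional remainder picks the model, and an unknown provider raises ValueError. A minimal sketch of the intended call pattern, assuming the relevant API keys (e.g. OPENAI_API_KEY, GROQ_API_KEY, COHERE_API_KEY) are set in the environment and using the default model ids hard-coded in this commit:

from models import get_model, get_embedding_model

# A bare provider name falls back to the default hard-coded in get_model
# (here ChatOpenAI with model="gpt-4o-mini").
chat = get_model("openai")

# "provider:model" selects an explicit model; temperature is passed through.
chat = get_model("groq:llama-3.1-8b-instant", temperature=0.2)
reply = chat.invoke("Say hello in one word.")
print(reply.content)

# Embedding models follow the same convention.
emb = get_embedding_model("cohere:embed-multilingual-v3")
vectors = emb.embed_documents(["first document", "second document"])
print(len(vectors), len(vectors[0]))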
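
Parsing detail worth calling out: split_provider_model splits on the first colon only, so colons inside a model id survive intact, which matters for Bedrock ids such as us.anthropic.claude-3-5-haiku-20241022-v1:0. A quick check of the behavior as committed:

from models import split_provider_model

# Only the first ":" separates provider from model.
assert split_provider_model("bedrock:us.anthropic.claude-3-5-haiku-20241022-v1:0") == (
    "bedrock", "us.anthropic.claude-3-5-haiku-20241022-v1:0")
# A missing model part, or a trailing colon, both normalize to None.
assert split_provider_model("ollama") == ("ollama", None)
assert split_provider_model("ollama:") == ("ollama", None)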
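
The new tests patch the provider class in models' namespace so no real client is constructed, then assert on the constructor arguments. A hypothetical test in the same style for the Cohere branch (not part of this commit; the expected arguments follow the defaults introduced above):

import unittest
from unittest.mock import patch

from models import get_model

class TestGetModelCohere(unittest.TestCase):
    @patch('models.ChatCohere')
    def test_cohere_default_model(self, mock_cohere):
        # get_model fills in the default model before constructing the client,
        # so the mock should see the resolved id, not None.
        result = get_model('cohere')
        mock_cohere.assert_called_once_with(model='command-r-plus', temperature=0.7)
        self.assertEqual(result, mock_cohere.return_value)

if __name__ == '__main__':
    unittest.main()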