Eddie Pick committed
Commit bb1d601 · 1 Parent(s): 9438062
Files changed (2)
  1. .gitignore +1 -0
  2. models.py +107 -63
.gitignore CHANGED
@@ -121,6 +121,7 @@ celerybeat.pid
 
 # Environments
 .env
+.env.local
 .venv
 env/
 venv/
models.py CHANGED
@@ -1,11 +1,5 @@
 import os
-import json
-from langchain.schema import SystemMessage, HumanMessage
-from langchain.prompts.chat import (
-    HumanMessagePromptTemplate,
-    SystemMessagePromptTemplate,
-    ChatPromptTemplate
-)
+from typing import Tuple, Optional
 from langchain.prompts.prompt import PromptTemplate
 from langchain.retrievers.multi_query import MultiQueryRetriever
 
@@ -17,83 +11,122 @@ from langchain_fireworks.embeddings import FireworksEmbeddings
 from langchain_groq.chat_models import ChatGroq
 from langchain_openai import ChatOpenAI
 from langchain_openai.embeddings import OpenAIEmbeddings
+from langchain_anthropic.chat_models import ChatAnthropic
+from langchain_mistralai.chat_models import ChatMistralAI
+from langchain_mistralai.embeddings import MistralAIEmbeddings
 from langchain_ollama.chat_models import ChatOllama
 from langchain_ollama.embeddings import OllamaEmbeddings
 from langchain_cohere.embeddings import CohereEmbeddings
 from langchain_cohere.chat_models import ChatCohere
+from langchain_huggingface import ChatHuggingFace, HuggingFaceEndpoint
+from langchain_community.embeddings import HuggingFaceInferenceAPIEmbeddings
 from langchain_openai.embeddings import OpenAIEmbeddings
 from langchain_google_genai import ChatGoogleGenerativeAI
 from langchain_google_genai.embeddings import GoogleGenerativeAIEmbeddings
 from langchain_community.chat_models import ChatPerplexity
 from langchain_together import ChatTogether
 from langchain_together.embeddings import TogetherEmbeddings
+from langchain.chat_models.base import BaseChatModel
+from langchain.embeddings.base import Embeddings
 
-def split_provider_model(provider_model):
-    parts = provider_model.split(':', 1)
+def split_provider_model(provider_model: str) -> Tuple[str, Optional[str]]:
+    parts = provider_model.split(":", 1)
     provider = parts[0]
-    model = parts[1] if len(parts) > 1 else None
+    if len(parts) > 1:
+        model = parts[1] if parts[1] else None
+    else:
+        model = None
     return provider, model
 
-def get_model(provider_model, temperature=0.0):
+def get_model(provider_model: str, temperature: float = 0.7) -> BaseChatModel:
+    """
+    Get a model from a provider and model name.
+    returns BaseChatModel
+    """
     provider, model = split_provider_model(provider_model)
-    match provider:
-        case 'bedrock':
-            if model is None:
-                model = "anthropic.claude-3-sonnet-20240229-v1:0"
-            chat_llm = ChatBedrockConverse(model=model, temperature=temperature)
-        case 'cohere':
-            if model is None:
-                model = 'command-r-plus'
-            chat_llm = ChatCohere(model=model, temperature=temperature)
-        case 'fireworks':
-            if model is None:
-                model = 'accounts/fireworks/models/llama-v3p1-8b-instruct'
-            chat_llm = ChatFireworks(model_name=model, temperature=temperature, max_tokens=120000)
-        case 'googlegenerativeai':
-            if model is None:
-                model = "gemini-1.5-flash"
-            chat_llm = ChatGoogleGenerativeAI(model=model, temperature=temperature,
-                                              max_tokens=None, timeout=None, max_retries=2,)
-        case 'groq':
-            if model is None:
-                model = 'llama-3.1-8b-instant'
-            chat_llm = ChatGroq(model_name=model, temperature=temperature)
-        case 'ollama':
-            if model is None:
-                model = 'llama3.1'
-            chat_llm = ChatOllama(model=model, temperature=temperature)
-        case 'openai':
-            if model is None:
-                model = "gpt-4o-mini"
-            chat_llm = ChatOpenAI(model=model, temperature=temperature)
-        case 'openrouter':
-            if model is None:
-                model = "google/gemini-flash-1.5-exp"
-            chat_llm = ChatOpenAI(model=model, temperature=temperature, base_url="https://openrouter.ai/api/v1", api_key=os.getenv("OPENROUTER_API_KEY"))
-        case 'perplexity':
-            if model is None:
-                model = 'llama-3.1-sonar-small-128k-online'
-            chat_llm = ChatPerplexity(model=model, temperature=temperature)
-        case 'together':
-            if model is None:
-                model = 'meta-llama/Meta-Llama-3.1-8B-Instruct-Turbo'
-            chat_llm = ChatTogether(model=model, temperature=temperature)
-        case _:
-            raise ValueError(f"Unknown LLM provider {provider}")
+    try:
+        match provider.lower():
+            case 'anthropic':
+                if model is None:
+                    model = "claude-3-sonnet-20240229"
+                chat_llm = ChatAnthropic(model=model, temperature=temperature)
+            case 'bedrock':
+                if model is None:
+                    model = "us.anthropic.claude-3-5-haiku-20241022-v1:0"
+                chat_llm = ChatBedrockConverse(model=model, temperature=temperature)
+            case 'cohere':
+                if model is None:
+                    model = 'command-r-plus'
+                chat_llm = ChatCohere(model=model, temperature=temperature)
+            case 'fireworks':
+                if model is None:
+                    model = 'accounts/fireworks/models/llama-v3p1-8b-instruct'
+                chat_llm = ChatFireworks(model_name=model, temperature=temperature, max_tokens=120000)
+            case 'googlegenerativeai':
+                if model is None:
+                    model = "gemini-1.5-flash"
+                chat_llm = ChatGoogleGenerativeAI(model=model, temperature=temperature,
+                                                  max_tokens=None, timeout=None, max_retries=2,)
+            case 'groq':
+                if model is None:
+                    model = 'llama-3.1-8b-instant'
+                chat_llm = ChatGroq(model_name=model, temperature=temperature)
+            case 'huggingface' | 'hf':
+                if model is None:
+                    model = 'mistralai/Mistral-Nemo-Instruct-2407'
+                llm = HuggingFaceEndpoint(
+                    repo_id=model,
+                    max_length=8192,
+                    temperature=temperature,
+                    huggingfacehub_api_token=os.getenv("HUGGINGFACE_API_KEY"),
+                )
+                chat_llm = ChatHuggingFace(llm=llm)
+            case 'ollama':
+                if model is None:
+                    model = 'llama3.1'
+                chat_llm = ChatOllama(model=model, temperature=temperature)
+            case 'openai':
+                if model is None:
+                    model = "gpt-4o-mini"
+                chat_llm = ChatOpenAI(model=model, temperature=temperature)
+            case 'openrouter':
+                if model is None:
+                    model = "google/gemini-flash-1.5-exp"
+                chat_llm = ChatOpenAI(model=model, temperature=temperature, base_url="https://openrouter.ai/api/v1", api_key=os.getenv("OPENROUTER_API_KEY"))
+            case 'mistralai' | 'mistral':
+                if model is None:
+                    model = "open-mistral-nemo"
+                chat_llm = ChatMistralAI(model=model, temperature=temperature)
+            case 'perplexity':
+                if model is None:
+                    model = 'llama-3.1-sonar-small-128k-online'
+                chat_llm = ChatPerplexity(model=model, temperature=temperature)
+            case 'together':
+                if model is None:
+                    model = 'meta-llama/Meta-Llama-3.1-8B-Instruct-Turbo'
+                chat_llm = ChatTogether(model=model, temperature=temperature)
+            case 'xai':
+                if model is None:
+                    model = 'grok-beta'
+                chat_llm = ChatOpenAI(model=model, api_key=os.getenv("XAI_API_KEY"), base_url="https://api.x.ai/v1", temperature=temperature)
+            case _:
+                raise ValueError(f"Unknown LLM provider {provider}")
+    except Exception as e:
+        raise ValueError(f"Unexpected error with {provider}: {str(e)}")
 
     return chat_llm
 
 
-def get_embedding_model(provider_model):
+def get_embedding_model(provider_model: str) -> Embeddings:
     provider, model = split_provider_model(provider_model)
-    match provider:
+    match provider.lower():
         case 'bedrock':
             if model is None:
                 model = "amazon.titan-embed-text-v2:0"
             embedding_model = BedrockEmbeddings(model_id=model)
         case 'cohere':
             if model is None:
-                model = "embed-english-light-v3.0"
+                model = "embed-multilingual-v3"
             embedding_model = CohereEmbeddings(model=model)
         case 'fireworks':
            if model is None:
@@ -113,6 +146,14 @@ def get_embedding_model(provider_model):
             embedding_model = GoogleGenerativeAIEmbeddings(model=model)
         case 'groq':
             embedding_model = OpenAIEmbeddings(model="text-embedding-3-small")
+        case 'huggingface' | 'hf':
+            if model is None:
+                model = 'sentence-transformers/all-MiniLM-L6-v2'
+            embedding_model = HuggingFaceInferenceAPIEmbeddings(model_name=model, api_key=os.getenv("HUGGINGFACE_API_KEY"))
+        case 'mistral':
+            if model is None:
+                model = "mistral-embed"
+            embedding_model = MistralAIEmbeddings(model=model)
         case 'perplexity':
             raise ValueError(f"Cannot use Perplexity for embedding model")
         case 'together':
@@ -193,12 +234,15 @@ from models import get_model # Make sure this import is correct
 class TestGetModel(unittest.TestCase):
 
     @patch('models.ChatBedrockConverse')
-    def test_bedrock_model(self, mock_bedrock):
+    def test_bedrock_model_no_specific_model(self, mock_bedrock):
         result = get_model('bedrock')
-        mock_bedrock.assert_called_once_with(
-            model="anthropic.claude-3-sonnet-20240229-v1:0",
-            temperature=0.0
-        )
+        mock_bedrock.assert_called_once_with(model=None, temperature=0.0)
+        self.assertEqual(result, mock_bedrock.return_value)
+
+    @patch('models.ChatBedrockConverse')
+    def test_bedrock_model_with_specific_model(self, mock_bedrock):
+        result = get_model('bedrock:specific-model')
+        mock_bedrock.assert_called_once_with(model='specific-model', temperature=0.0)
         self.assertEqual(result, mock_bedrock.return_value)
 
     @patch('models.ChatCohere')
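
For reference, a minimal usage sketch of the helpers this commit changes, assuming models.py is importable as models and the relevant provider API keys (for example MISTRAL_API_KEY or HUGGINGFACE_API_KEY) are set in the environment; the provider:model strings below are illustrative choices, not prescribed by the commit.

# Minimal sketch, assuming models.py is on the import path and API keys are exported.
from models import split_provider_model, get_model, get_embedding_model

# The string splits at the first ':'; after this commit an empty or missing
# model id maps to None instead of an empty string.
assert split_provider_model("bedrock") == ("bedrock", None)
assert split_provider_model("openai:gpt-4o-mini") == ("openai", "gpt-4o-mini")
assert split_provider_model("mistral:") == ("mistral", None)

# Providers added in this commit: anthropic, huggingface/hf, mistralai/mistral, xai.
chat_llm = get_model("mistral:open-mistral-nemo", temperature=0.2)
print(chat_llm.invoke("Say hello in one word.").content)

embedding_model = get_embedding_model("hf:sentence-transformers/all-MiniLM-L6-v2")
print(len(embedding_model.embed_query("hello world")))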