# search_agent/models.py

import os
from typing import Optional, Tuple

from langchain_anthropic.chat_models import ChatAnthropic
from langchain_aws import BedrockEmbeddings
from langchain_aws.chat_models.bedrock_converse import ChatBedrockConverse
from langchain_cohere.chat_models import ChatCohere
from langchain_cohere.embeddings import CohereEmbeddings
from langchain_community.chat_models import ChatPerplexity
from langchain_community.embeddings import HuggingFaceInferenceAPIEmbeddings
from langchain_core.embeddings import Embeddings
from langchain_core.language_models.chat_models import BaseChatModel
from langchain_fireworks.chat_models import ChatFireworks
from langchain_fireworks.embeddings import FireworksEmbeddings
from langchain_google_genai import ChatGoogleGenerativeAI
from langchain_google_genai.embeddings import GoogleGenerativeAIEmbeddings
from langchain_groq.chat_models import ChatGroq
from langchain_huggingface import ChatHuggingFace, HuggingFaceEndpoint
from langchain_mistralai.chat_models import ChatMistralAI
from langchain_mistralai.embeddings import MistralAIEmbeddings
from langchain_ollama.chat_models import ChatOllama
from langchain_ollama.embeddings import OllamaEmbeddings
from langchain_openai import ChatOpenAI
from langchain_openai.embeddings import OpenAIEmbeddings
from langchain_together import ChatTogether
from langchain_together.embeddings import TogetherEmbeddings


def split_provider_model(provider_model: str) -> Tuple[str, Optional[str]]:
    """Split a "provider:model" string into a (provider, model) pair.

    Only the first colon is used as the separator, so model IDs that contain
    colons themselves (e.g. Bedrock model IDs) pass through intact. The model
    part is None when it is missing or empty.
    """
    parts = provider_model.split(":", 1)
    provider = parts[0]
    model = parts[1] if len(parts) > 1 and parts[1] else None
    return provider, model
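
# Illustrative parses (maxsplit=1 keeps any later colons inside the model ID):
#     split_provider_model("openai:gpt-4o-mini")  ->  ("openai", "gpt-4o-mini")
#     split_provider_model("openai")              ->  ("openai", None)
#     split_provider_model("bedrock:us.anthropic.claude-3-5-haiku-20241022-v1:0")
#         ->  ("bedrock", "us.anthropic.claude-3-5-haiku-20241022-v1:0")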


def get_model(provider_model: str, temperature: float = 0.7) -> BaseChatModel:
    """Build a chat model from a "provider:model" string.

    Each provider falls back to a default model when none is specified.
    """
provider, model = split_provider_model(provider_model)
try:
match provider.lower():
case 'anthropic':
if model is None:
model = "claude-3-5-haiku-20241022"
chat_llm = ChatAnthropic(model=model, temperature=temperature)
case 'bedrock':
if model is None:
model = "us.anthropic.claude-3-5-haiku-20241022-v1:0"
chat_llm = ChatBedrockConverse(model=model, temperature=temperature)
case 'cohere':
if model is None:
model = 'command-r-plus'
chat_llm = ChatCohere(model=model, temperature=temperature)
            case 'deepseek':
                if model is None:
                    model = 'deepseek-chat'
                # DeepSeek exposes an OpenAI-compatible API, so ChatOpenAI is
                # pointed at the DeepSeek endpoint.
                chat_llm = ChatOpenAI(
                    model=model,
                    temperature=temperature,
                    api_key=os.getenv("DEEPSEEK_API_KEY"),
                    base_url='https://api.deepseek.com',
                    max_tokens=8192,
                )
case 'fireworks':
if model is None:
model = 'accounts/fireworks/models/llama-v3p3-70b-instruct'
chat_llm = ChatFireworks(model_name=model, temperature=temperature, max_tokens=120000)
case 'googlegenerativeai':
if model is None:
model = "gemini-2.0-flash-exp"
                chat_llm = ChatGoogleGenerativeAI(
                    model=model,
                    temperature=temperature,
                    max_tokens=None,
                    timeout=None,
                    max_retries=2,
                )
case 'groq':
if model is None:
model = 'qwen-2.5-32b'
chat_llm = ChatGroq(model_name=model, temperature=temperature)
case 'huggingface' | 'hf':
if model is None:
model = 'Qwen/Qwen2.5-72B-Instruct'
llm = HuggingFaceEndpoint(
repo_id=model,
temperature=temperature,
)
chat_llm = ChatHuggingFace(llm=llm)
case 'ollama':
if model is None:
model = 'llama3.1'
chat_llm = ChatOllama(model=model, temperature=temperature)
case 'openai':
if model is None:
model = "gpt-4o-mini"
chat_llm = ChatOpenAI(model=model, temperature=temperature)
case 'openrouter':
if model is None:
model = "cognitivecomputations/dolphin3.0-mistral-24b:free"
                chat_llm = ChatOpenAI(
                    model=model,
                    temperature=temperature,
                    base_url="https://openrouter.ai/api/v1",
                    api_key=os.getenv("OPENROUTER_API_KEY"),
                )
case 'mistralai' | 'mistral':
if model is None:
model = "mistral-small-latest"
chat_llm = ChatMistralAI(model=model, temperature=temperature)
case 'perplexity':
if model is None:
model = 'sonar'
chat_llm = ChatPerplexity(model=model, temperature=temperature)
case 'together':
if model is None:
model = 'meta-llama/Llama-3.3-70B-Instruct-Turbo-Free'
chat_llm = ChatTogether(model=model, temperature=temperature)
case 'xai':
if model is None:
model = 'grok-2-1212'
                chat_llm = ChatOpenAI(
                    model=model,
                    temperature=temperature,
                    api_key=os.getenv("XAI_API_KEY"),
                    base_url="https://api.x.ai/v1",
                )
case _:
raise ValueError(f"Unknown LLM provider {provider}")
    except Exception as e:
        raise ValueError(f"Unexpected error with {provider}: {e}") from e
return chat_llm
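
# Usage sketch (illustrative; assumes the matching provider API key is set in
# the environment):
#     llm = get_model("openai:gpt-4o-mini", temperature=0.2)
#     reply = llm.invoke("Hello!")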


def get_embedding_model(provider_model: str) -> Embeddings:
    """Build an embedding model from a "provider:model" string.

    Each provider falls back to a default model when none is specified.
    """
provider, model = split_provider_model(provider_model)
match provider.lower():
case 'bedrock':
if model is None:
model = "amazon.titan-embed-text-v2:0"
embedding_model = BedrockEmbeddings(model_id=model)
case 'cohere':
if model is None:
model = "embed-multilingual-v3.0"
embedding_model = CohereEmbeddings(model=model)
case 'fireworks':
if model is None:
model = 'nomic-ai/nomic-embed-text-v1.5'
embedding_model = FireworksEmbeddings(model=model)
case 'ollama':
if model is None:
model = 'nomic-embed-text:latest'
embedding_model = OllamaEmbeddings(model=model)
case 'openai':
if model is None:
model = "text-embedding-3-small"
embedding_model = OpenAIEmbeddings(model=model)
case 'googlegenerativeai':
if model is None:
model = "models/embedding-001"
embedding_model = GoogleGenerativeAIEmbeddings(model=model)
        case 'groq':
            # Groq offers no embedding endpoint, so fall back to OpenAI
            # embeddings (requires OPENAI_API_KEY).
            embedding_model = OpenAIEmbeddings(model="text-embedding-3-small")
case 'huggingface' | 'hf':
if model is None:
model = 'sentence-transformers/all-MiniLM-L6-v2'
embedding_model = HuggingFaceInferenceAPIEmbeddings(model_name=model, api_key=os.getenv("HUGGINGFACE_API_KEY"))
case 'mistral':
if model is None:
model = "mistral-embed"
embedding_model = MistralAIEmbeddings(model=model)
case 'perplexity':
raise ValueError(f"Cannot use Perplexity for embedding model")
case 'together':
if model is None:
model = 'togethercomputer/m2-bert-80M-2k-retrieval'
embedding_model = TogetherEmbeddings(model=model)
case _:
raise ValueError(f"Unknown LLM provider {provider}")
return embedding_model
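
# Usage sketch (hypothetical query; requires the provider's credentials):
#     embeddings = get_embedding_model("openai:text-embedding-3-small")
#     vector = embeddings.embed_query("What is a search agent?")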


# ---------------------------------------------------------------------------
# Test suite. These tests import from models.py, so they belong in a separate
# test module alongside it.
# ---------------------------------------------------------------------------
import unittest
from unittest.mock import patch

from models import get_embedding_model, get_model

class TestGetEmbeddingModel(unittest.TestCase):

    @patch('models.BedrockEmbeddings')
    def test_bedrock_embedding(self, mock_bedrock):
        result = get_embedding_model('bedrock')
        mock_bedrock.assert_called_once_with(model_id='amazon.titan-embed-text-v2:0')
        self.assertEqual(result, mock_bedrock.return_value)

    @patch('models.CohereEmbeddings')
    def test_cohere_embedding(self, mock_cohere):
        result = get_embedding_model('cohere')
        mock_cohere.assert_called_once_with(model='embed-multilingual-v3.0')
        self.assertEqual(result, mock_cohere.return_value)

    @patch('models.FireworksEmbeddings')
    def test_fireworks_embedding(self, mock_fireworks):
        result = get_embedding_model('fireworks')
        mock_fireworks.assert_called_once_with(model='nomic-ai/nomic-embed-text-v1.5')
        self.assertEqual(result, mock_fireworks.return_value)

    @patch('models.OllamaEmbeddings')
    def test_ollama_embedding(self, mock_ollama):
        result = get_embedding_model('ollama')
        mock_ollama.assert_called_once_with(model='nomic-embed-text:latest')
        self.assertEqual(result, mock_ollama.return_value)

    @patch('models.OpenAIEmbeddings')
    def test_openai_embedding(self, mock_openai):
        result = get_embedding_model('openai')
        mock_openai.assert_called_once_with(model='text-embedding-3-small')
        self.assertEqual(result, mock_openai.return_value)

    @patch('models.GoogleGenerativeAIEmbeddings')
    def test_google_embedding(self, mock_google):
        result = get_embedding_model('googlegenerativeai')
        mock_google.assert_called_once_with(model='models/embedding-001')
        self.assertEqual(result, mock_google.return_value)

    @patch('models.TogetherEmbeddings')
    def test_together_embedding(self, mock_together):
        result = get_embedding_model('together')
        mock_together.assert_called_once_with(model='togethercomputer/m2-bert-80M-2k-retrieval')
        self.assertEqual(result, mock_together.return_value)

    def test_invalid_provider(self):
        with self.assertRaises(ValueError):
            get_embedding_model('invalid_provider')

    @patch('models.OpenAIEmbeddings')
    def test_groq_provider(self, mock_openai):
        # Groq has no embedding model, so the code falls back to OpenAI embeddings.
        result = get_embedding_model('groq')
        mock_openai.assert_called_once_with(model='text-embedding-3-small')
        self.assertEqual(result, mock_openai.return_value)

    def test_perplexity_provider(self):
        with self.assertRaises(ValueError):
            get_embedding_model('perplexity')


class TestGetModel(unittest.TestCase):

    @patch('models.ChatBedrockConverse')
    def test_bedrock_model_no_specific_model(self, mock_bedrock):
        result = get_model('bedrock')
        mock_bedrock.assert_called_once_with(
            model='us.anthropic.claude-3-5-haiku-20241022-v1:0', temperature=0.7)
        self.assertEqual(result, mock_bedrock.return_value)

    @patch('models.ChatBedrockConverse')
    def test_bedrock_model_with_specific_model(self, mock_bedrock):
        result = get_model('bedrock:specific-model')
        mock_bedrock.assert_called_once_with(model='specific-model', temperature=0.7)
        self.assertEqual(result, mock_bedrock.return_value)

    @patch('models.ChatCohere')
    def test_cohere_model(self, mock_cohere):
        result = get_model('cohere')
        mock_cohere.assert_called_once_with(model='command-r-plus', temperature=0.7)
        self.assertEqual(result, mock_cohere.return_value)

    @patch('models.ChatFireworks')
    def test_fireworks_model(self, mock_fireworks):
        result = get_model('fireworks')
        mock_fireworks.assert_called_once_with(
            model_name='accounts/fireworks/models/llama-v3p3-70b-instruct',
            temperature=0.7,
            max_tokens=120000
        )
        self.assertEqual(result, mock_fireworks.return_value)

    @patch('models.ChatGoogleGenerativeAI')
    def test_google_model(self, mock_google):
        result = get_model('googlegenerativeai')
        mock_google.assert_called_once_with(
            model="gemini-2.0-flash-exp",
            temperature=0.7,
            max_tokens=None,
            timeout=None,
            max_retries=2
        )
        self.assertEqual(result, mock_google.return_value)

    @patch('models.ChatGroq')
    def test_groq_model(self, mock_groq):
        result = get_model('groq')
        mock_groq.assert_called_once_with(model_name='qwen-2.5-32b', temperature=0.7)
        self.assertEqual(result, mock_groq.return_value)

    @patch('models.ChatOllama')
    def test_ollama_model(self, mock_ollama):
        result = get_model('ollama')
        mock_ollama.assert_called_once_with(model='llama3.1', temperature=0.7)
        self.assertEqual(result, mock_ollama.return_value)

    @patch('models.ChatOpenAI')
    def test_openai_model(self, mock_openai):
        result = get_model('openai')
        mock_openai.assert_called_once_with(model='gpt-4o-mini', temperature=0.7)
        self.assertEqual(result, mock_openai.return_value)

    @patch('models.ChatPerplexity')
    def test_perplexity_model(self, mock_perplexity):
        result = get_model('perplexity')
        mock_perplexity.assert_called_once_with(model='sonar', temperature=0.7)
        self.assertEqual(result, mock_perplexity.return_value)

    @patch('models.ChatTogether')
    def test_together_model(self, mock_together):
        result = get_model('together')
        mock_together.assert_called_once_with(
            model='meta-llama/Llama-3.3-70B-Instruct-Turbo-Free', temperature=0.7)
        self.assertEqual(result, mock_together.return_value)

    def test_invalid_provider(self):
        with self.assertRaises(ValueError):
            get_model('invalid_provider')

    def test_custom_temperature(self):
        with patch('models.ChatOpenAI') as mock_openai:
            result = get_model('openai', temperature=0.5)
            mock_openai.assert_called_once_with(model='gpt-4o-mini', temperature=0.5)
            self.assertEqual(result, mock_openai.return_value)

    def test_custom_model(self):
        with patch('models.ChatOpenAI') as mock_openai:
            result = get_model('openai:gpt-4')
            mock_openai.assert_called_once_with(model='gpt-4', temperature=0.7)
            self.assertEqual(result, mock_openai.return_value)


if __name__ == '__main__':
    unittest.main()
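
# The suite can also be run via unittest discovery from the project root:
#     python -m unittest discover -v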