""" Helper functions to access LLMs. """ import logging import re import sys from typing import Tuple, Union import requests from requests.adapters import HTTPAdapter from urllib3.util import Retry from langchain_core.language_models import BaseLLM sys.path.append('..') from global_config import GlobalConfig LLM_PROVIDER_MODEL_REGEX = re.compile(r'\[(.*?)\](.*)') # 6-64 characters long, only containing alphanumeric characters, hyphens, and underscores API_KEY_REGEX = re.compile(r'^[a-zA-Z0-9\-_]{6,64}$') HF_API_HEADERS = {'Authorization': f'Bearer {GlobalConfig.HUGGINGFACEHUB_API_TOKEN}'} REQUEST_TIMEOUT = 35 logger = logging.getLogger(__name__) logging.getLogger('httpx').setLevel(logging.WARNING) logging.getLogger('httpcore').setLevel(logging.WARNING) retries = Retry( total=5, backoff_factor=0.25, backoff_jitter=0.3, status_forcelist=[502, 503, 504], allowed_methods={'POST'}, ) adapter = HTTPAdapter(max_retries=retries) http_session = requests.Session() http_session.mount('https://', adapter) http_session.mount('http://', adapter) def get_provider_model(provider_model: str) -> Tuple[str, str]: """ Parse and get LLM provider and model name from strings like `[provider]model/name-version`. :param provider_model: The provider, model name string from `GlobalConfig`. :return: The provider and the model name. """ match = LLM_PROVIDER_MODEL_REGEX.match(provider_model) if match: inside_brackets = match.group(1) outside_brackets = match.group(2) return inside_brackets, outside_brackets return '', '' def is_valid_llm_provider_model(provider: str, model: str, api_key: str) -> bool: """ Verify whether LLM settings are proper. This function does not verify whether `api_key` is correct. It only confirms that the key has at least five characters. Key verification is done when the LLM is created. :param provider: Name of the LLM provider. :param model: Name of the model. :param api_key: The API key or access token. :return: `True` if the settings "look" OK; `False` otherwise. """ if not provider or not model or provider not in GlobalConfig.VALID_PROVIDERS: return False if provider in [ GlobalConfig.PROVIDER_GOOGLE_GEMINI, GlobalConfig.PROVIDER_COHERE, ] and not api_key: return False if api_key: return API_KEY_REGEX.match(api_key) is not None return True def get_langchain_llm( provider: str, model: str, max_new_tokens: int, api_key: str = '' ) -> Union[BaseLLM, None]: """ Get an LLM based on the provider and model specified. :param provider: The LLM provider. Valid values are `hf` for Hugging Face. :param model: The name of the LLM. :param max_new_tokens: The maximum number of tokens to generate. :param api_key: API key or access token to use. :return: An instance of the LLM or `None` in case of any error. """ if provider == GlobalConfig.PROVIDER_HUGGING_FACE: from langchain_community.llms.huggingface_endpoint import HuggingFaceEndpoint logger.debug('Getting LLM via HF endpoint: %s', model) return HuggingFaceEndpoint( repo_id=model, max_new_tokens=max_new_tokens, top_k=40, top_p=0.95, temperature=GlobalConfig.LLM_MODEL_TEMPERATURE, repetition_penalty=1.03, streaming=True, huggingfacehub_api_token=api_key or GlobalConfig.HUGGINGFACEHUB_API_TOKEN, return_full_text=False, stop_sequences=[''], ) if provider == GlobalConfig.PROVIDER_GOOGLE_GEMINI: from google.generativeai.types.safety_types import HarmBlockThreshold, HarmCategory from langchain_google_genai import GoogleGenerativeAI logger.debug('Getting LLM via Google Gemini: %s', model) return GoogleGenerativeAI( model=model, temperature=GlobalConfig.LLM_MODEL_TEMPERATURE, max_tokens=max_new_tokens, timeout=None, max_retries=2, google_api_key=api_key, safety_settings={ HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT: HarmBlockThreshold.BLOCK_MEDIUM_AND_ABOVE, HarmCategory.HARM_CATEGORY_HATE_SPEECH: HarmBlockThreshold.BLOCK_MEDIUM_AND_ABOVE, HarmCategory.HARM_CATEGORY_HARASSMENT: HarmBlockThreshold.BLOCK_MEDIUM_AND_ABOVE, HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT: HarmBlockThreshold.BLOCK_LOW_AND_ABOVE } ) if provider == GlobalConfig.PROVIDER_COHERE: from langchain_cohere.llms import Cohere logger.debug('Getting LLM via Cohere: %s', model) return Cohere( temperature=GlobalConfig.LLM_MODEL_TEMPERATURE, max_tokens=max_new_tokens, timeout_seconds=None, max_retries=2, cohere_api_key=api_key, streaming=True, ) return None if __name__ == '__main__': inputs = [ '[co]Cohere', '[hf]mistralai/Mistral-7B-Instruct-v0.2', '[gg]gemini-1.5-flash-002' ] for text in inputs: print(get_provider_model(text))