import logging
import re
from typing import Tuple, Union
import requests
from requests.adapters import HTTPAdapter
from urllib3.util import Retry
from langchain_community.llms.huggingface_endpoint import HuggingFaceEndpoint
from langchain_core.language_models import LLM
from global_config import GlobalConfig

LLM_PROVIDER_MODEL_REGEX = re.compile(r'\[(.*?)\](.*)')
HF_API_HEADERS = {'Authorization': f'Bearer {GlobalConfig.HUGGINGFACEHUB_API_TOKEN}'}
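# Timeout (in seconds, the unit `requests` expects) for calls to the HF API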
REQUEST_TIMEOUT = 35
logger = logging.getLogger(__name__)
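
# A shared `requests` session that retries transient upstream failures (502/503/504)
# with exponential backoff and jitter. POST is allowed explicitly because urllib3
# does not retry it by default, and HF inference calls are POST requests.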
retries = Retry(
total=5,
backoff_factor=0.25,
backoff_jitter=0.3,
status_forcelist=[502, 503, 504],
allowed_methods={'POST'},
)
adapter = HTTPAdapter(max_retries=retries)
http_session = requests.Session()
http_session.mount('https://', adapter)
http_session.mount('http://', adapter)


def get_provider_model(provider_model: str) -> Tuple[str, str]:
    """
    Parse the LLM provider and model name from a string like `[provider]model/name-version`.

    :param provider_model: The provider/model string from `GlobalConfig`.
    :return: The provider and the model name; two empty strings if the format is invalid.
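
    Example (same inputs as the demo in the `__main__` block below):

    >>> get_provider_model('[hf]mistralai/Mistral-7B-Instruct-v0.2')
    ('hf', 'mistralai/Mistral-7B-Instruct-v0.2')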
"""
    match = LLM_PROVIDER_MODEL_REGEX.match(provider_model)
    if match:
        return match.group(1), match.group(2)

    return '', ''


def get_hf_endpoint(repo_id: str, max_new_tokens: int, api_key: str = '') -> LLM:
    """
    Get an LLM via LangChain's `HuggingFaceEndpoint`.

    :param repo_id: The model name (HF repository ID).
    :param max_new_tokens: The maximum number of new tokens to generate.
    :param api_key: [Optional] Hugging Face access token; falls back to `GlobalConfig`.
    :return: The HF LLM inference endpoint.
    """
logger.debug('Getting LLM via HF endpoint: %s', repo_id)
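    # Decoding settings: top-k filtering combined with nucleus (top-p) sampling and a
    # mild repetition penalty; `return_full_text=False` returns only the completion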
return HuggingFaceEndpoint(
repo_id=repo_id,
max_new_tokens=max_new_tokens,
top_k=40,
top_p=0.95,
temperature=GlobalConfig.LLM_MODEL_TEMPERATURE,
repetition_penalty=1.03,
streaming=True,
huggingfacehub_api_token=api_key or GlobalConfig.HUGGINGFACEHUB_API_TOKEN,
return_full_text=False,
stop_sequences=['</s>'],
)


def get_langchain_llm(
        provider: str,
        model: str,
        max_new_tokens: int,
        api_key: str = ''
) -> Union[LLM, None]:
    """
    Get an LLM based on the specified provider and model.

    :param provider: The LLM provider; must be in `GlobalConfig.VALID_PROVIDERS`.
     Currently, only `hf` (Hugging Face) is handled.
    :param model: The name of the LLM.
    :param max_new_tokens: The maximum number of new tokens to generate.
    :param api_key: [Optional] API key or access token for the LLM provider.
    :return: An LLM instance, or `None` if the provider or model is unrecognized.
    """
if not provider or not model or provider not in GlobalConfig.VALID_PROVIDERS:
return None
    if provider == 'hf':
        logger.debug('Getting LLM via HF endpoint: %s', model)
        # Reuse get_hf_endpoint rather than duplicating the endpoint construction
        return get_hf_endpoint(model, max_new_tokens, api_key)
return None


if __name__ == '__main__':
inputs = [
'[hf]mistralai/Mistral-7B-Instruct-v0.2',
'[gg]gemini-1.5-flash-002'
]
for text in inputs:
print(get_provider_model(text))
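
    # A minimal, hedged usage sketch for `get_langchain_llm` (commented out because it
    # needs a valid Hugging Face token in `GlobalConfig` plus network access):
    #
    # llm = get_langchain_llm('hf', 'mistralai/Mistral-7B-Instruct-v0.2', max_new_tokens=128)
    # if llm:
    #     print(llm.invoke('Briefly describe the Moon.'))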