# helpers/llm_helper.py
import logging
import re
from typing import Tuple, Union

import requests
from requests.adapters import HTTPAdapter
from urllib3.util import Retry
from langchain_community.llms.huggingface_endpoint import HuggingFaceEndpoint
from langchain_core.language_models import LLM

from global_config import GlobalConfig


LLM_PROVIDER_MODEL_REGEX = re.compile(r'\[(.*?)\](.*)')
HF_API_HEADERS = {'Authorization': f'Bearer {GlobalConfig.HUGGINGFACEHUB_API_TOKEN}'}
REQUEST_TIMEOUT = 35

logger = logging.getLogger(__name__)

# A shared HTTP session that retries POST requests on transient server-side
# failures (502/503/504), using exponential backoff with jitter
retries = Retry(
    total=5,
    backoff_factor=0.25,
    backoff_jitter=0.3,
    status_forcelist=[502, 503, 504],
    allowed_methods={'POST'},
)
adapter = HTTPAdapter(max_retries=retries)
http_session = requests.Session()
http_session.mount('https://', adapter)
http_session.mount('http://', adapter)


def get_provider_model(provider_model: str) -> Tuple[str, str]:
"""
Parse and get LLM provider and model name from strings like `[provider]model/name-version`.
:param provider_model: The provider, model name string from `GlobalConfig`.
:return: The provider and the model name.
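
    Illustrative example:
        >>> get_provider_model('[hf]mistralai/Mistral-7B-Instruct-v0.2')
        ('hf', 'mistralai/Mistral-7B-Instruct-v0.2')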
"""
    match = LLM_PROVIDER_MODEL_REGEX.match(provider_model)
    if match:
        provider = match.group(1)
        model = match.group(2)
        return provider, model

    return '', ''


def get_hf_endpoint(repo_id: str, max_new_tokens: int, api_key: str = '') -> LLM:
"""
Get an LLM via the HuggingFaceEndpoint of LangChain.
:param repo_id: The model name.
:param max_new_tokens: The max new tokens to generate.
:param api_key: [Optional] Hugging Face access token.
:return: The HF LLM inference endpoint.
"""
logger.debug('Getting LLM via HF endpoint: %s', repo_id)
return HuggingFaceEndpoint(
repo_id=repo_id,
max_new_tokens=max_new_tokens,
top_k=40,
top_p=0.95,
temperature=GlobalConfig.LLM_MODEL_TEMPERATURE,
repetition_penalty=1.03,
streaming=True,
huggingfacehub_api_token=api_key or GlobalConfig.HUGGINGFACEHUB_API_TOKEN,
return_full_text=False,
stop_sequences=['</s>'],
)
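

# Illustrative usage (a sketch, not part of the original module; assumes a valid
# Hugging Face access token is configured via GlobalConfig or passed explicitly,
# and that the model repo is accessible):
#     llm = get_hf_endpoint('mistralai/Mistral-7B-Instruct-v0.2', max_new_tokens=512)
#     for chunk in llm.stream('Suggest a title for a deck on renewable energy'):
#         print(chunk, end='')
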
def get_langchain_llm(
provider: str,
model: str,
max_new_tokens: int,
api_key: str = ''
) -> Union[LLM, None]:
"""
Get an LLM based on the provider and model specified.
:param provider: The LLM provider. Valid values are `hf` for Hugging Face.
:param model:
:param max_new_tokens:
:param api_key:
:return:
"""
if not provider or not model or provider not in GlobalConfig.VALID_PROVIDERS:
return None
    if provider == 'hf':
        # Reuse get_hf_endpoint() instead of duplicating the endpoint configuration
        return get_hf_endpoint(repo_id=model, max_new_tokens=max_new_tokens, api_key=api_key)

    return None


if __name__ == '__main__':
inputs = [
'[hf]mistralai/Mistral-7B-Instruct-v0.2',
'[gg]gemini-1.5-flash-002'
]
for text in inputs:
print(get_provider_model(text))
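
    # Illustrative sketch (not in the original demo): create an actual LLM for a
    # parsed provider/model pair, but only if an HF access token is configured
    if GlobalConfig.HUGGINGFACEHUB_API_TOKEN:
        provider, model = get_provider_model('[hf]mistralai/Mistral-7B-Instruct-v0.2')
        llm = get_langchain_llm(provider, model, max_new_tokens=128)
        print(f'Created an LLM instance: {llm is not None}')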