"""
Helper functions to access LLMs via LangChain's HuggingFaceEndpoint.
"""
import logging

import requests
from langchain_community.llms.huggingface_endpoint import HuggingFaceEndpoint
from langchain_core.language_models import LLM
from requests.adapters import HTTPAdapter
from urllib3.util import Retry

from global_config import GlobalConfig

HF_API_HEADERS = {'Authorization': f'Bearer {GlobalConfig.HUGGINGFACEHUB_API_TOKEN}'}
REQUEST_TIMEOUT = 35

logger = logging.getLogger(__name__)

# Retry POST requests on transient server-side failures (bad gateway, service
# unavailable, gateway timeout) with exponential backoff and jitter
retries = Retry(
    total=5,
    backoff_factor=0.25,
    backoff_jitter=0.3,
    status_forcelist=[502, 503, 504],
    allowed_methods={'POST'},
)
adapter = HTTPAdapter(max_retries=retries)
http_session = requests.Session()
http_session.mount('https://', adapter)
http_session.mount('http://', adapter)
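
# A minimal usage sketch of the retry-enabled session; the endpoint URL and
# payload below are illustrative assumptions, not values defined in this module:
#
#     response = http_session.post(
#         'https://api-inference.huggingface.co/models/<repo_id>',  # hypothetical URL
#         headers=HF_API_HEADERS,
#         json={'inputs': 'prompt text'},
#         timeout=REQUEST_TIMEOUT,
#     )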


def get_hf_endpoint(repo_id: str, max_new_tokens: int) -> LLM:
    """
    Get an LLM via LangChain's HuggingFaceEndpoint.

    :param repo_id: The Hugging Face repo ID of the model to use.
    :param max_new_tokens: The maximum number of new tokens to generate.
    :return: The HF LLM inference endpoint.
    """
    logger.debug('Getting LLM via HF endpoint: %s', repo_id)

    return HuggingFaceEndpoint(
        repo_id=repo_id,
        max_new_tokens=max_new_tokens,
        top_k=40,
        top_p=0.95,
        temperature=GlobalConfig.LLM_MODEL_TEMPERATURE,
        repetition_penalty=1.03,
        streaming=True,
        huggingfacehub_api_token=GlobalConfig.HUGGINGFACEHUB_API_TOKEN,
        return_full_text=False,
        stop_sequences=['</s>'],
    )
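
# A usage sketch: since streaming=True is set above, tokens can be consumed
# incrementally via LangChain's stream() interface. The repo ID and token
# budget here are illustrative assumptions, not values fixed by this module:
#
#     llm = get_hf_endpoint('mistralai/Mistral-7B-Instruct-v0.2', max_new_tokens=2048)
#     for chunk in llm.stream('Summarize the benefits of 5G in two sentences.'):
#         print(chunk, end='')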


# def hf_api_query(payload: dict) -> dict:
# """
# Invoke HF inference end-point API.
#
# :param payload: The prompt for the LLM and related parameters.
# :return: The output from the LLM.
# """
#
# try:
# response = http_session.post(
# HF_API_URL,
# headers=HF_API_HEADERS,
# json=payload,
# timeout=REQUEST_TIMEOUT
# )
# result = response.json()
# except requests.exceptions.Timeout as te:
# logger.error('*** Error: hf_api_query timeout! %s', str(te))
# result = []
#
# return result


# def generate_slides_content(topic: str) -> str:
# """
# Generate the outline/contents of slides for a presentation on a given topic.
#
# :param topic: Topic on which slides are to be generated.
# :return: The content in JSON format.
# """
#
# with open(GlobalConfig.SLIDES_TEMPLATE_FILE, 'r', encoding='utf-8') as in_file:
# template_txt = in_file.read().strip()
# template_txt = template_txt.replace('<REPLACE_PLACEHOLDER>', topic)
#
# output = hf_api_query({
# 'inputs': template_txt,
# 'parameters': {
# 'temperature': GlobalConfig.LLM_MODEL_TEMPERATURE,
# 'min_length': GlobalConfig.LLM_MODEL_MIN_OUTPUT_LENGTH,
# 'max_length': GlobalConfig.LLM_MODEL_MAX_OUTPUT_LENGTH,
# 'max_new_tokens': GlobalConfig.LLM_MODEL_MAX_OUTPUT_LENGTH,
# 'num_return_sequences': 1,
# 'return_full_text': False,
# # "repetition_penalty": 0.0001
# },
# 'options': {
# 'wait_for_model': True,
# 'use_cache': True
# }
# })
#
# output = output[0]['generated_text'].strip()
# # output = output[len(template_txt):]
#
# json_end_idx = output.rfind('```')
# if json_end_idx != -1:
# # logging.debug(f'{json_end_idx=}')
# output = output[:json_end_idx]
#
# logger.debug('generate_slides_content: output: %s', output)
#
# return output


if __name__ == '__main__':
# results = get_related_websites('5G AI WiFi 6')
#
# for a_result in results.results:
# print(a_result.title, a_result.url, a_result.extract)
# get_ai_image('A talk on AI, covering pros and cons')
pass