|
from __future__ import annotations |
|
|
|
import time |
|
|
|
from typing import Any, Callable, List, Optional |
|
|
|
from langchain.callbacks.manager import CallbackManagerForLLMRun |
|
from langchain.llms.huggingface_hub import HuggingFaceHub |
|
|
|
import logging |
|
logger = logging.getLogger(__name__) |
|
|
|
from tenacity import ( |
|
before_sleep_log, |
|
retry, |
|
retry_if_exception_type, |
|
stop_after_attempt, |
|
wait_exponential, |
|
) |
|
|
|
|
|
def _create_retry_decorator(llm: KronHuggingFaceHub) -> Callable[[Any], Any]:
    """Build the tenacity decorator that retries rate-limited HF Hub calls.

    Args:
        llm: Model whose ``max_retries`` attribute caps the number of attempts.

    Returns:
        A decorator that re-invokes the wrapped callable whenever it raises
        ``KronHFHubRateExceededException``. Any other exception is not
        retried. Once the attempts are exhausted the last exception is
        re-raised as-is (``reraise=True``).
    """
    # Give up once ``llm.max_retries`` attempts have been made.
    attempt_limit = stop_after_attempt(llm.max_retries)
    # Exponential backoff, bounded to the 4-10 second range.
    backoff = wait_exponential(multiplier=1, min=4, max=10)
    # Only the rate-limit signal is worth retrying.
    rate_limited_only = retry_if_exception_type(KronHFHubRateExceededException)
    # Emit a WARNING on the module logger before each sleep.
    warn_before_sleep = before_sleep_log(logger, logging.WARNING)

    return retry(
        reraise=True,
        stop=attempt_limit,
        wait=backoff,
        retry=rate_limited_only,
        before_sleep=warn_before_sleep,
    )
|
|
|
|
|
def completion_with_retry(llm: KronHuggingFaceHub, **kwargs: Any) -> Any:
    """Run ``llm.internal_call`` under the tenacity retry policy.

    Args:
        llm: Model to call; it also supplies the retry configuration through
            ``_create_retry_decorator``.
        **kwargs: Forwarded unchanged to ``llm.internal_call``.

    Returns:
        Whatever ``llm.internal_call`` returns on the first successful attempt.
    """
    with_retries = _create_retry_decorator(llm)

    # The inner function keeps its name because tenacity's before-sleep log
    # reports the name of the callable it wraps.
    def _completion_with_retry(**call_kwargs: Any) -> Any:
        return llm.internal_call(**call_kwargs)

    return with_retries(_completion_with_retry)(**kwargs)
|
|
|
class KronHFHubRateExceededException(Exception):
    """Error signalling that the HF Hub is unavailable because the rate was exceeded.

    Attributes:
        message: Human-readable description; it is also passed to
            ``Exception`` so ``str(exc)`` returns the same text.
    """

    def __init__(self, message: str = "HF Hub Service Unavailable: Rate exceeded.") -> None:
        super().__init__(message)
        self.message = message
|
|
|
|
|
class KronHuggingFaceHub(HuggingFaceHub):
    """``HuggingFaceHub`` LLM that retries generation when the Hub is rate limited.

    ``_call`` routes each request through ``completion_with_retry``, which
    invokes ``internal_call`` under the module's tenacity retry policy.
    """

    max_retries: int = 10
    """Maximum number of attempts when generating.

    The value is handed to tenacity's ``stop_after_attempt``, so it counts the
    first call as well as the retries.
    """

    def internal_call(
        self,
        prompt: str,
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> str:
        """Perform one un-retried call to the parent ``_call``.

        The prompt and the response are echoed to stdout as a debug trace.

        Args:
            prompt: Prompt text sent to the model.
            stop: Optional stop sequences, forwarded to the parent ``_call``.
            run_manager: Optional callback manager, forwarded to the parent.
            **kwargs: Extra keyword arguments forwarded to the parent.

        Returns:
            The response produced by the parent ``_call``.

        Raises:
            KronHFHubRateExceededException: If the parent raises a
                ``ValueError`` whose text contains "Service Unavailable";
                the original error is attached as ``__cause__``.
            ValueError: Any other ``ValueError`` from the parent, re-raised
                unchanged.
        """
        print(f'**************************************\n{prompt}')
        # Keep the try body to the one call whose ValueError is translated.
        try:
            response = super()._call(prompt, stop, run_manager, **kwargs)
        except ValueError as ve:
            if "Service Unavailable" in str(ve):
                # Translate into the retryable error, keeping the cause chained.
                raise KronHFHubRateExceededException() from ve
            # Bare raise preserves the original traceback.
            raise
        print(f'**************************************\n{response}')
        return response

    def _call(
        self,
        prompt: str,
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> str:
        """Generate a completion, retrying via ``completion_with_retry``.

        Args:
            prompt: Prompt text sent to the model.
            stop: Optional stop sequences.
            run_manager: Optional callback manager for the run.
            **kwargs: Extra keyword arguments forwarded to ``internal_call``.

        Returns:
            The response of the first successful ``internal_call``.
        """
        return completion_with_retry(
            self,
            prompt=prompt,
            stop=stop,
            run_manager=run_manager,
            **kwargs,
        )