# Source: mlk8s/kron/llm_predictor/KronHFHubLLM.py (2.63 kB)
# Last change: commit 65964b2 by Arylwen — "0.1.1 refactor and ui changes"
from __future__ import annotations
import time
from typing import Any, Callable, List, Optional
from langchain.callbacks.manager import CallbackManagerForLLMRun
from langchain.llms.huggingface_hub import HuggingFaceHub
import logging
logger = logging.getLogger(__name__)
from tenacity import (
before_sleep_log,
retry,
retry_if_exception_type,
stop_after_attempt,
wait_exponential,
)
def _create_retry_decorator(llm: KronHuggingFaceHub) -> Callable[[Any], Any]:
    """Build the tenacity decorator that retries rate-limited HF Hub calls.

    Only ``KronHFHubRateExceededException`` triggers another attempt; any
    other exception propagates immediately. After ``llm.max_retries``
    attempts the last exception is re-raised (``reraise=True``), and each
    back-off sleep is logged at WARNING on the module logger.

    Args:
        llm: The model whose ``max_retries`` bounds the number of attempts.

    Returns:
        A decorator that wraps a callable with this retry policy.
    """
    # Exponential back-off, clamped between 4 and 10 seconds per wait.
    backoff = wait_exponential(multiplier=1, min=4, max=10)
    give_up = stop_after_attempt(llm.max_retries)
    only_rate_limits = retry_if_exception_type(KronHFHubRateExceededException)
    log_before_sleep = before_sleep_log(logger, logging.WARNING)

    return retry(
        reraise=True,
        stop=give_up,
        wait=backoff,
        retry=only_rate_limits,
        before_sleep=log_before_sleep,
    )
def completion_with_retry(llm: KronHuggingFaceHub, **kwargs: Any) -> Any:
    """Call ``llm.internal_call(**kwargs)`` under the module's retry policy.

    Args:
        llm: The model to invoke; also supplies the retry limit.
        **kwargs: Forwarded unchanged to ``llm.internal_call``.

    Returns:
        Whatever ``llm.internal_call`` returns on the first successful attempt.
    """
    def _attempt(**call_kwargs: Any) -> Any:
        # One try; the retry wrapper decides whether to call this again.
        return llm.internal_call(**call_kwargs)

    guarded_attempt = _create_retry_decorator(llm)(_attempt)
    return guarded_attempt(**kwargs)
class KronHFHubRateExceededException(Exception):
    """Raised when the HF Hub reports it is unavailable, so the call may be retried.

    The text is stored on ``message`` and is also the exception's single
    positional argument, so ``str(exc)`` returns it.
    """

    def __init__(self, message: str = "HF Hub Service Unavailable: Rate exceeded.") -> None:
        super().__init__(message)
        self.message = message
class KronHuggingFaceHub(HuggingFaceHub):
    """``HuggingFaceHub`` LLM that retries while the Hub reports it is unavailable.

    ``_call`` routes every generation through ``completion_with_retry``, which
    re-invokes ``internal_call`` whenever it raises
    ``KronHFHubRateExceededException``.
    """

    # Passed to tenacity's ``stop_after_attempt``, so this bounds the total
    # number of attempts for one generation (the first call included).
    max_retries: int = 10

    def internal_call(
        self,
        prompt: str,
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> str:
        """Make one un-retried call to the HF Hub via ``HuggingFaceHub._call``.

        Args:
            prompt: Text sent to the model.
            stop: Optional stop sequences, forwarded to the base ``_call``.
            run_manager: Optional LangChain callback manager for this run.
            **kwargs: Extra arguments forwarded to the base ``_call``.

        Returns:
            The text returned by the base ``_call``.

        Raises:
            KronHFHubRateExceededException: if the base ``_call`` raises a
                ``ValueError`` whose text contains "Service Unavailable"; the
                original error is chained as ``__cause__``.
            ValueError: any other ``ValueError`` from the base ``_call``,
                re-raised unchanged.
        """
        # DEBUG logging instead of print(): prompts and completions can be
        # large or sensitive and should not be forced onto stdout.
        logger.debug("HF Hub prompt:\n%s", prompt)
        try:
            response = super()._call(prompt, stop, run_manager, **kwargs)
        except ValueError as ve:
            if "Service Unavailable" in str(ve):
                # Translate into the retryable type, keeping the Hub's message
                # in the traceback via exception chaining.
                raise KronHFHubRateExceededException() from ve
            raise
        logger.debug("HF Hub response:\n%s", response)
        return response

    def _call(
        self,
        prompt: str,
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> str:
        """Generate a completion, retrying while the Hub reports it is unavailable.

        Overrides the base ``_call`` so generation requests go through
        ``completion_with_retry`` and, from there, ``internal_call``.
        """
        return completion_with_retry(
            self, prompt=prompt, stop=stop, run_manager=run_manager, **kwargs
        )