File size: 2,629 Bytes
65964b2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
from __future__ import annotations

import time

from typing import Any, Callable, List, Optional

from langchain.callbacks.manager import CallbackManagerForLLMRun
from langchain.llms.huggingface_hub import HuggingFaceHub

import logging
logger = logging.getLogger(__name__)

from tenacity import (
    before_sleep_log,
    retry,
    retry_if_exception_type,
    stop_after_attempt,
    wait_exponential,
)


def _create_retry_decorator(llm: KronHuggingFaceHub) -> Callable[[Any], Any]:
    """Build the tenacity ``retry`` decorator used for calls made by *llm*.

    The decorator retries only when ``KronHFHubRateExceededException`` is
    raised, stops after ``llm.max_retries`` attempts, logs a WARNING through
    the module logger before each sleep, and re-raises the last exception
    once the attempts are used up (``reraise=True``).

    Args:
        llm: The LLM wrapper whose ``max_retries`` bounds the attempt count.

    Returns:
        The configured decorator, ready to be applied to a callable.
    """
    attempt_limit = stop_after_attempt(llm.max_retries)
    # Exponential backoff (multiplier 1) bounded below by 4 seconds and
    # above by 10 seconds.
    backoff = wait_exponential(multiplier=1, min=4, max=10)
    only_rate_exceeded = retry_if_exception_type(KronHFHubRateExceededException)
    warn_before_sleep = before_sleep_log(logger, logging.WARNING)

    return retry(
        reraise=True,
        stop=attempt_limit,
        wait=backoff,
        retry=only_rate_exceeded,
        before_sleep=warn_before_sleep,
    )


def completion_with_retry(llm: KronHuggingFaceHub, **kwargs: Any) -> Any:
    """Run ``llm.internal_call(**kwargs)`` under the retry policy for *llm*.

    Args:
        llm: The LLM wrapper to call; also supplies the retry configuration.
        **kwargs: Forwarded unchanged to ``llm.internal_call``.

    Returns:
        Whatever ``llm.internal_call`` returns on the first successful attempt.
    """
    with_retry = _create_retry_decorator(llm)

    # The inner name is kept as-is: the retry logger reports the wrapped
    # function's name in its "before sleep" messages.
    def _completion_with_retry(**kwargs: Any) -> Any:
        return llm.internal_call(**kwargs)

    guarded_call = with_retry(_completion_with_retry)
    return guarded_call(**kwargs)

class KronHFHubRateExceededException(Exception):
    """Error carrying an "HF Hub Service Unavailable: Rate exceeded." message.

    The text is available both as ``str(exc)`` and as ``exc.message``.
    """

    def __init__(self, message="HF Hub Service Unavailable: Rate exceeded."):
        # Hand the text to Exception first so args/str() are populated,
        # then expose it under the explicit ``message`` attribute.
        super().__init__(message)
        self.message = message


class KronHuggingFaceHub(HuggingFaceHub):
    """``HuggingFaceHub`` LLM that retries "Service Unavailable" failures.

    ``_call`` sends every request through ``completion_with_retry``, which
    repeatedly invokes ``internal_call`` for as long as it raises
    ``KronHFHubRateExceededException`` and the retry policy allows.
    """

    max_retries: int = 10
    """Maximum number of retries to make when generating."""

    def internal_call(
        self,
        prompt: str,
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> str:
        """Make a single, un-retried call to the parent class's ``_call``.

        The prompt and the response are each echoed to stdout after a
        separator line.

        Args:
            prompt: The prompt text passed to the parent ``_call``.
            stop: Optional stop sequences, forwarded to the parent.
            run_manager: Optional callback manager, forwarded to the parent.
            **kwargs: Extra keyword arguments, forwarded to the parent.

        Returns:
            The response returned by the parent ``_call``.

        Raises:
            KronHFHubRateExceededException: The parent raised a ``ValueError``
                whose text contains "Service Unavailable". The original error
                is attached as ``__cause__``.
            ValueError: Any other ``ValueError`` from the parent, re-raised
                unchanged.
        """
        print(f'**************************************\n{prompt}')
        try:
            response = super()._call(prompt, stop, run_manager, **kwargs)
        except ValueError as ve:
            if "Service Unavailable" in str(ve):
                # Translate to the retryable exception type, keeping the
                # original error as the explicit cause for diagnostics.
                raise KronHFHubRateExceededException() from ve
            # Not a rate-limit failure: propagate with its traceback intact.
            raise
        print(f'**************************************\n{response}')
        return response

    def _call(
        self,
        prompt: str,
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> str:
        """Generate a completion for *prompt*, retrying via ``completion_with_retry``.

        Args:
            prompt: The prompt text.
            stop: Optional stop sequences.
            run_manager: Optional callback manager for the run.
            **kwargs: Extra keyword arguments forwarded to ``internal_call``.

        Returns:
            The response produced by ``internal_call``.
        """
        return completion_with_retry(
            self, prompt=prompt, stop=stop, run_manager=run_manager, **kwargs
        )