import time
from typing import Any, List, Mapping, Optional

from hugchat import hugchat
from hugchat.login import Login

from langchain.callbacks.manager import CallbackManagerForLLMRun
from langchain.llms.base import LLM


# Custom LangChain LLM wrapper based on the hugchat library.
# References:
# - LangChain custom LLM wrapper: https://python.langchain.com/docs/modules/model_io/models/llms/how_to/custom_llm
# - HugChat library: https://github.com/Soulter/hugging-chat-api
# - Author: Alessandro Ciciarelli, owner of IntelligenzaArtificialeItalia.net; my dream is to democratize AI and make it accessible to everyone.
# A usage sketch is included at the end of this file.

# Subclass of hugchat's Login that overrides _get_auth_url to handle both the
# JSON (200) and redirect (303) responses returned by the Hugging Face chat
# login endpoint. The imported name is intentionally shadowed so the rest of
# the module keeps using `Login`.
class Login(Login):
    def _get_auth_url(self):
        url = "https://huggingface.co/chat/login"
        headers = {
            "Referer": "https://huggingface.co/chat/login",
            "User-Agent": self.headers["User-Agent"],
            "Content-Type": "application/x-www-form-urlencoded",
            "Origin": "https://huggingface.co/chat"
        }
        res = self._request_post(url, headers=headers, allow_redirects=False)
        # print(res.status_code, res.headers, res.text)  # uncomment to inspect the raw login response
        if res.status_code == 200:
            # The authorize URL is returned in the JSON body.
            location = res.json().get("location")
            if location:
                return location
            raise Exception(
                "No authorize url found, please check your email or password.")
        elif res.status_code == 303:
            # Or as a redirect Location header.
            location = res.headers.get("Location")
            if location:
                return location
            raise Exception(
                "No authorize url found, please check your email or password.")
        else:
            raise Exception(f"Login request failed with status code {res.status_code}.")

class HuggingChat(LLM):
    """HuggingChat LLM wrapper."""

    chatbot: Optional[hugchat.ChatBot] = None

    # Credentials: either email + psw, or a cookie_path pointing to saved cookies.
    email: Optional[str] = None
    psw: Optional[str] = None
    cookie_path: Optional[str] = None

    conversation: Optional[str] = None
    model: Optional[int] = 0  # 0 = OpenAssistant/oasst-sft-6-llama-30b-xor, 1 = meta-llama/Llama-2-70b-chat-hf

    # Generation parameters forwarded to hugchat's ChatBot.chat().
    temperature: Optional[float] = 0.9
    top_p: Optional[float] = 0.95
    repetition_penalty: Optional[float] = 1.2
    top_k: Optional[int] = 50
    truncate: Optional[int] = 1024
    watermark: Optional[bool] = False
    max_new_tokens: Optional[int] = 1024
    stop: Optional[List[str]] = ["</s>"]
    return_full_text: Optional[bool] = False
    stream_resp: Optional[bool] = True
    use_cache: Optional[bool] = False
    is_retry: Optional[bool] = False
    retry_count: Optional[int] = 5

    avg_response_time: float = 0.0
    log: Optional[bool] = False
    
    
    @property
    def _llm_type(self) -> str:
        return "🤗CUSTOM LLM WRAPPER Based on hugging-chat-api library"
    

    def create_chatbot(self) -> None:
        if not ((self.email and self.psw) or self.cookie_path):
            raise ValueError("Either the email and psw pair or cookie_path is required.")

        try:
            if self.email and self.psw:
                # Log in with email and password to obtain session cookies.
                start_time = time.time()
                sign = Login(self.email, self.psw)
                cookies = sign.login()
                end_time = time.time()
                if self.log:
                    print(f"\n[LOG] Login successful in {round(end_time - start_time)} seconds")
                self.chatbot = hugchat.ChatBot(cookies=cookies.get_dict())
            else:
                # Reuse cookies previously saved to cookie_path.
                self.chatbot = hugchat.ChatBot(cookie_path=self.cookie_path)

            if self.log:
                print("[LOG] LLM WRAPPER created successfully")

        except Exception as e:
            raise ValueError("Login failed. Please check your credentials or cookie_path. " + str(e))

        # Set up the ChatBot: select the model and the conversation to use.
        self.chatbot.switch_llm(self.model)
        if self.log:
            print(f"[LOG] LLM WRAPPER switched to model { 'OpenAssistant/oasst-sft-6-llama-30b-xor' if self.model == 0 else 'meta-llama/Llama-2-70b-chat-hf'}")

        self.conversation = self.conversation or self.chatbot.new_conversation()
        self.chatbot.change_conversation(self.conversation)
        if self.log:
            print(f"[LOG] LLM WRAPPER changed conversation to {self.conversation}\n")
        


    def _call(
        self,
        prompt: str,
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> str:
        if stop:
            raise ValueError("stop kwargs are not permitted.")

        # Lazily create the chatbot on the first call.
        if not self.chatbot:
            self.create_chatbot()

        try:
            if self.log:
                print(f"[LOG] LLM WRAPPER called with prompt: {prompt}")
            start_time = time.time()
            resp = self.chatbot.chat(
                prompt,
                temperature=self.temperature,
                top_p=self.top_p,
                repetition_penalty=self.repetition_penalty,
                top_k=self.top_k,
                truncate=self.truncate,
                watermark=self.watermark,
                max_new_tokens=self.max_new_tokens,
                stop=self.stop,
                return_full_text=self.return_full_text,
                stream=self.stream_resp,
                use_cache=self.use_cache,
                is_retry=self.is_retry,
                retry_count=self.retry_count,
            )
            end_time = time.time()

            # Crude running estimate: average of the previous estimate and the latest response time.
            elapsed = end_time - start_time
            self.avg_response_time = (self.avg_response_time + elapsed) / 2 if self.avg_response_time else elapsed

            if self.log:
                print(f"[LOG] LLM WRAPPER response time: {round(elapsed)} seconds")
                print(f"[LOG] LLM WRAPPER avg response time: {round(self.avg_response_time)} seconds")
                print(f"[LOG] LLM WRAPPER response: {resp}\n\n")

            return str(resp)

        except Exception as e:
            raise ValueError("ChatBot failed, please check your parameters. " + str(e))

    @property
    def _identifying_params(self) -> Mapping[str, Any]:
        """Get the identifying parameters."""
        params = {
            "model": "HuggingChat",
            "temperature": self.temperature,
            "top_p": self.top_p,
            "repetition_penalty": self.repetition_penalty,
            "top_k": self.top_k,
            "truncate": self.truncate,
            "watermark": self.watermark,
            "max_new_tokens": self.max_new_tokens,
            "stop": self.stop,
            "return_full_text": self.return_full_text,
            "stream": self.stream_resp,
            "use_cache": self.use_cache,
            "is_retry": self.is_retry,
            "retry_count": self.retry_count,
            "avg_response_time": self.avg_response_time,
        }
        return params
    
    @property
    def _get_avg_response_time(self) -> float:
        """Get the average response time."""
        return self.avg_response_time
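

# ---------------------------------------------------------------------------
# Usage sketch (not part of the wrapper itself): a minimal example of how this
# class could be used, both directly and inside a LangChain LLMChain. The
# HF_EMAIL / HF_PASSWORD environment variable names are placeholders chosen
# for illustration, not something this module defines.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    import os

    from langchain.chains import LLMChain
    from langchain.prompts import PromptTemplate

    # Assumes valid Hugging Face credentials are available in the environment.
    llm = HuggingChat(
        email=os.environ.get("HF_EMAIL"),
        psw=os.environ.get("HF_PASSWORD"),
        model=1,   # meta-llama/Llama-2-70b-chat-hf
        log=True,
    )

    # Direct call: the chatbot is created lazily on the first invocation.
    print(llm("What is LangChain in one sentence?"))

    # Use inside a simple LangChain chain.
    prompt = PromptTemplate(
        input_variables=["topic"],
        template="Explain {topic} to a beginner in two sentences.",
    )
    chain = LLMChain(llm=llm, prompt=prompt)
    print(chain.run(topic="custom LLM wrappers"))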