Spaces:
Sleeping
Sleeping
File size: 7,260 Bytes
79340f2 14a6ed5 79340f2 14a6ed5 79340f2 6843f1c 79340f2 14a6ed5 79340f2 6843f1c 79340f2 14a6ed5 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 |
from hugchat import hugchat
from hugchat.login import Login
import time
from typing import Any, List, Mapping, Optional
from langchain.callbacks.manager import CallbackManagerForLLMRun
from langchain.llms.base import LLM
# This is a custom LLM wrapper based on the hugchat library.
# References:
# - LangChain custom LLM wrapper: https://python.langchain.com/docs/modules/model_io/models/llms/how_to/custom_llm
# - HugChat library: https://github.com/Soulter/hugging-chat-api
# - I am Alessandro Ciciarelli, the owner of IntelligenzaArtificialeItalia.net; my dream is to democratize AI and make it accessible to everyone.
class Login(Login):
    """Variant of ``hugchat.login.Login`` with a custom auth-URL lookup.

    This subclass deliberately reuses (and therefore shadows) the imported
    name ``Login`` so the rest of the module picks up the overridden
    ``_get_auth_url`` without further changes.
    """

    def _get_auth_url(self):
        """POST to the HuggingChat login endpoint and return the authorize URL.

        The URL is read from the ``location`` key of the JSON body when the
        server answers 200, or from the ``Location`` header when it answers
        303.

        Returns:
            The authorize/redirect URL reported by the server.

        Raises:
            Exception: if the response carries no URL, or if the status code
                is neither 200 nor 303.
        """
        url = "https://huggingface.co/chat/login"
        headers = {
            "Referer": "https://huggingface.co/chat/login",
            "User-Agent": self.headers["User-Agent"],
            "Content-Type": "application/x-www-form-urlencoded",
            "Origin": "https://huggingface.co/chat",
        }
        res = self._request_post(url, headers=headers, allow_redirects=False)
        # The response headers and body are intentionally not printed: a login
        # response may carry session cookies or tokens that must not end up in
        # stdout or hosted logs.

        if res.status_code == 200:
            try:
                # .get() so a missing key yields the explicit error below
                # instead of a bare KeyError.
                location = res.json().get("location")
            except (ValueError, AttributeError):
                # Body was not JSON, or the JSON payload was not an object.
                location = None
        elif res.status_code == 303:
            location = res.headers.get("Location")
        else:
            raise Exception("Something went wrong!")

        if not location:
            raise Exception(
                "No authorize url found, please check your email or password.")
        return location
class HuggingChat(LLM):
    """LangChain ``LLM`` wrapper that answers prompts through ``hugchat.ChatBot``.

    Authentication uses either ``email`` + ``psw`` (a fresh login) or
    ``cookie_path`` (handed to ``hugchat.ChatBot``). The chatbot is created
    lazily on the first call to ``_call`` if ``create_chatbot`` has not been
    invoked explicitly.
    """

    # Underlying hugchat session; populated by create_chatbot().
    chatbot: Optional[hugchat.ChatBot] = None
    # Credentials: either email + psw, or cookie_path.
    email: Optional[str] = None
    psw: Optional[str] = None
    cookie_path: Optional[str] = None
    # Conversation passed to change_conversation(); a new one is created when unset.
    conversation: Optional[str] = None
    model: Optional[int] = 0  # 0 = OpenAssistant/oasst-sft-6-llama-30b-xor , 1 = meta-llama/Llama-2-70b-chat-hf
    # Generation settings, forwarded unchanged to ChatBot.chat() by _call().
    temperature: Optional[float] = 0.9
    top_p: Optional[float] = 0.95
    repetition_penalty: Optional[float] = 1.2
    top_k: Optional[int] = 50
    truncate: Optional[int] = 1024
    watermark: Optional[bool] = False
    max_new_tokens: Optional[int] = 1024
    stop: Optional[list] = ["</s>"]
    return_full_text: Optional[bool] = False
    stream_resp: Optional[bool] = True
    use_cache: Optional[bool] = False
    is_retry: Optional[bool] = False
    retry_count: Optional[int] = 5
    # Smoothed latency in seconds, updated by _call(). Not an arithmetic mean:
    # each update averages the previous value with the newest latency.
    avg_response_time: float = 0.0
    # When True, progress messages are printed to stdout.
    log: Optional[bool] = False

    @property
    def _llm_type(self) -> str:
        """Return the identifier string LangChain uses for this LLM type."""
        return "🤗CUSTOM LLM WRAPPER Based on hugging-chat-api library"

    def _log(self, message: str) -> None:
        """Print ``message`` only when logging is enabled on this instance."""
        if self.log:
            print(message)

    def create_chatbot(self) -> None:
        """Authenticate, build the ``hugchat.ChatBot`` and select model/conversation.

        When both ``email`` and ``psw`` are set, a login is performed and the
        resulting cookies are used; otherwise ``cookie_path`` is passed to
        ``hugchat.ChatBot``.

        Raises:
            ValueError: if no credentials were supplied, or if login / chatbot
                construction fails (the underlying error is chained).
        """
        if not any([self.email, self.psw, self.cookie_path]):
            raise ValueError("email, psw, or cookie_path is required.")

        try:
            if self.email and self.psw:
                # Create a ChatBot using email and psw
                start_time = time.time()
                sign = Login(self.email, self.psw)
                cookies = sign.login()
                end_time = time.time()
                self._log(f"\n[LOG] Login successfull in {round(end_time - start_time)} seconds")
                cookie_dict = cookies.get_dict()
                if not cookie_dict:
                    # Fail here rather than storing a falsy non-ChatBot value
                    # that would break later with an unrelated AttributeError.
                    raise ValueError("login returned no cookies.")
                chatbot = hugchat.ChatBot(cookies=cookie_dict)
            elif self.cookie_path:
                # Create a ChatBot using cookie_path. The ChatBot is used
                # directly; it is not a cookie jar, so get_dict() must not be
                # called on it.
                chatbot = hugchat.ChatBot(cookie_path=self.cookie_path)
            else:
                # Only one of email / psw was given and there is no cookie_path.
                raise ValueError("both email and psw are required when cookie_path is not set.")
            self.chatbot = chatbot
            self._log("[LOG] LLM WRAPPER created successfully")
        except Exception as e:
            raise ValueError("LogIn failed. Please check your credentials or cookie_path. " + str(e)) from e

        # Setup ChatBot info
        self.chatbot.switch_llm(self.model)
        model_name = (
            "OpenAssistant/oasst-sft-6-llama-30b-xor"
            if self.model == 0
            else "meta-llama/Llama-2-70b-chat-hf"
        )
        self._log(f"[LOG] LLM WRAPPER switched to model {model_name}")

        self.conversation = self.conversation or self.chatbot.new_conversation()
        self.chatbot.change_conversation(self.conversation)
        self._log(f"[LOG] LLM WRAPPER changed conversation to {self.conversation}\n")

    def _call(
        self,
        prompt: str,
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> str:
        """Send ``prompt`` to the chatbot and return its reply as a string.

        Args:
            prompt: Text sent to ``ChatBot.chat``.
            stop: Must be empty/None; per-call stop sequences are rejected
                (the instance-level ``stop`` field is used instead).
            run_manager: Accepted for LangChain compatibility; unused here.
            **kwargs: Accepted for LangChain compatibility; unused here.

        Returns:
            ``str()`` of whatever ``ChatBot.chat`` returned.

        Raises:
            ValueError: if ``stop`` is given, if chatbot creation fails, or if
                the chat request fails (the underlying error is chained).
        """
        if stop:
            raise ValueError("stop kwargs are not permitted.")

        # Lazily create the session on first use.
        if not self.chatbot:
            self.create_chatbot()

        try:
            self._log(f"[LOG] LLM WRAPPER called with prompt: {prompt}")
            start_time = time.time()
            resp = self.chatbot.chat(
                prompt,
                temperature=self.temperature,
                top_p=self.top_p,
                repetition_penalty=self.repetition_penalty,
                top_k=self.top_k,
                truncate=self.truncate,
                watermark=self.watermark,
                max_new_tokens=self.max_new_tokens,
                stop=self.stop,
                return_full_text=self.return_full_text,
                stream=self.stream_resp,
                use_cache=self.use_cache,
                is_retry=self.is_retry,
                retry_count=self.retry_count,
            )
            end_time = time.time()
            elapsed = end_time - start_time

            # First measurement seeds the value; afterwards the previous value
            # and the newest latency are averaged.
            if self.avg_response_time:
                self.avg_response_time = (self.avg_response_time + elapsed) / 2
            else:
                self.avg_response_time = elapsed

            self._log(f"[LOG] LLM WRAPPER response time: {round(elapsed)} seconds")
            self._log(f"[LOG] LLM WRAPPER avg response time: {round(self.avg_response_time)} seconds")
            self._log(f"[LOG] LLM WRAPPER response: {resp}\n\n")
            return str(resp)
        except Exception as e:
            raise ValueError("ChatBot failed, please check your parameters. " + str(e)) from e

    @property
    def _identifying_params(self) -> Mapping[str, Any]:
        """Get the identifying parameters."""
        parms = {
            "model": "HuggingChat",
            "temperature": self.temperature,
            "top_p": self.top_p,
            "repetition_penalty": self.repetition_penalty,
            "top_k": self.top_k,
            "truncate": self.truncate,
            "watermark": self.watermark,
            "max_new_tokens": self.max_new_tokens,
            "stop": self.stop,
            "return_full_text": self.return_full_text,
            "stream": self.stream_resp,
            "use_cache": self.use_cache,
            "is_retry": self.is_retry,
            "retry_count": self.retry_count,
            "avg_response_time": self.avg_response_time,
        }
        return parms

    @property
    def _get_avg_response_time(self) -> float:
        """Get the average response time."""
        return self.avg_response_time
|