Spaces:
Runtime error
Runtime error
File size: 8,955 Bytes
b994311 3871450 b994311 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 |
import base64
import hmac
import json
from datetime import datetime, timezone
from urllib.parse import urlencode, urlparse
from websocket import create_connection, WebSocketConnectionClosedException
from utils.tools import get_prompt, process_response, init_script, create_script
class SparkAPI:
    """WebSocket client for the Spark chat endpoint configured in ``__api_url``.

    Each request opens a fresh, signed WebSocket connection, sends one JSON
    payload and reads response frames until ``process_response`` reports the
    final frame (``status == 2``) or an empty response.
    """

    __api_url = 'wss://spark-api.xf-yun.com/v2.1/chat'
    # Upper bound applied to the ``max_tokens`` value of every request.
    __max_token = 4096

    def __init__(self, app_id, api_key, api_secret):
        """Store the credentials used to sign and address every request."""
        self.__app_id = app_id
        self.__api_key = api_key
        self.__api_secret = api_secret

    def __set_max_tokens(self, token):
        """Override the per-instance ``max_tokens`` ceiling.

        Invalid values (non-integers, zero, negatives) are reported on stdout
        and leave the current ceiling untouched.
        """
        if not isinstance(token, int) or token <= 0:
            print("set_max_tokens() error: tokens should be a positive integer!")
            return
        self.__max_token = token

    def __get_authorization_url(self):
        """Return ``__api_url`` with the signed ``authorization``/``date``/``host`` query.

        The signature is an HMAC-SHA256 (keyed with the API secret) over the
        host, the date and the request line, base64-encoded and wrapped in a
        base64-encoded authorization string.
        """
        # Local import keeps this block self-contained.
        from email.utils import format_datetime

        authorize_url = urlparse(self.__api_url)

        # RFC 1123 date in GMT. format_datetime() is locale-independent,
        # unlike strftime('%a ... %b'), whose day/month names follow the
        # process locale.
        date = format_datetime(datetime.now(timezone.utc), usegmt=True)

        # The signed text is: host line, date line, then the request line of
        # the actual endpoint being called.
        signature_origin = "host: {}\ndate: {}\nGET {} HTTP/1.1".format(
            authorize_url.netloc, date, authorize_url.path
        )
        signature = base64.b64encode(
            hmac.new(
                self.__api_secret.encode(),
                signature_origin.encode(),
                digestmod='sha256'
            ).digest()
        ).decode()

        authorization_origin = \
            'api_key="{}",algorithm="{}",headers="{}",signature="{}"'.format(
                self.__api_key, "hmac-sha256", "host date request-line", signature
            )
        authorization = base64.b64encode(
            authorization_origin.encode()).decode()

        params = {
            "authorization": authorization,
            "date": date,
            "host": authorize_url.netloc
        }
        return self.__api_url + "?" + urlencode(params)

    def __build_inputs(
        self,
        message: dict,
        user_id: str = "001",
        domain: str = "general",
        temperature: float = 0.5,
        max_tokens: int = 4096
    ):
        """Serialize one chat request (header, parameter, payload) to JSON text."""
        # NOTE(review): the endpoint path is v2.1 while the default domain is
        # "general" - confirm against the service docs that they match.
        input_dict = {
            "header": {
                "app_id": self.__app_id,
                "uid": user_id,
            },
            "parameter": {
                "chat": {
                    "domain": domain,
                    "temperature": temperature,
                    "max_tokens": max_tokens,
                }
            },
            "payload": {
                "message": message
            }
        }
        return json.dumps(input_dict)

    def __open_request(self, query, history, user_id, domain, max_tokens,
                       temperature):
        """Connect, send the request for ``query`` and return the open socket.

        The socket is closed before re-raising if building or sending the
        request fails, so a failed request never leaks a connection.
        """
        # Never ask for more tokens than the configured ceiling.
        max_tokens = min(max_tokens, self.__max_token)
        ws = create_connection(self.__get_authorization_url())
        try:
            message = get_prompt(query, history)
            ws.send(self.__build_inputs(
                message=message,
                user_id=user_id,
                domain=domain,
                temperature=temperature,
                max_tokens=max_tokens,
            ))
        except BaseException:
            ws.close()
            raise
        return ws

    def chat(
        self,
        query: str,
        history: list = None,  # store the conversation history
        user_id: str = "001",
        domain: str = "general",
        max_tokens: int = 4096,
        temperature: float = 0.5,
    ):
        """Send ``query`` and return the response of the last processed frame.

        Returns ``None`` (after printing "Connection closed") when the server
        closes the connection while further frames are being read.
        """
        if history is None:
            history = []
        ws = self.__open_request(
            query, history, user_id, domain, max_tokens, temperature)
        try:
            # A closure during the very first receive propagates to the caller.
            response_str = ws.recv()
            try:
                while True:
                    response, history, status = process_response(
                        response_str, history)
                    # status == 2 marks the final frame of a conversation.
                    # doc url: https://www.xfyun.cn/doc/spark/Web.html#_1-%E6%8E%A5%E5%8F%A3%E8%AF%B4%E6%98%8E
                    if not response or status == 2:
                        return response
                    response_str = ws.recv()
            except WebSocketConnectionClosedException:
                print("Connection closed")
                return None
        finally:
            ws.close()

    # Stream output statement, used for terminal chat.
    def streaming_output(
        self,
        query: str,
        history: list = None,  # store the conversation history
        user_id: str = "001",
        domain: str = "general",
        max_tokens: int = 4096,
        temperature: float = 0.5,
    ):
        """Yield ``(response, history)`` after every processed frame.

        The connection is opened on the first iteration and closed when the
        generator finishes, is closed early, or fails.
        """
        if history is None:
            history = []
        # Send the question or prompt, then receive the answer frame by frame.
        ws = self.__open_request(
            query, history, user_id, domain, max_tokens, temperature)
        try:
            # A closure during the very first receive propagates to the caller.
            response_str = ws.recv()
            try:
                while True:
                    response, history, status = process_response(
                        response_str, history)
                    yield response, history
                    if not response or status == 2:
                        break
                    response_str = ws.recv()
            except WebSocketConnectionClosedException:
                print("Connection closed")
        finally:
            ws.close()

    def chat_stream(self):
        """Run an interactive terminal chat until the user types exit/stop.

        ``init`` loads a script file and turns it into a host prompt,
        ``create`` collects the fields of a new script and stores it.
        """
        history = []
        try:
            print("输入init来初始化剧本,输入create来创作剧本,输入exit或stop来终止对话\n")
            while True:
                query = input("Ask: ")
                if query == 'init':
                    jsonfile = input("请输入剧本文件路径:")
                    script_data = init_script(history, jsonfile)
                    print(
                        f"正在导入剧本{script_data['name']},角色信息:{script_data['characters']},剧情介绍:{script_data['summary']}")
                    # The loaded script becomes the prompt sent below.
                    query = f"我希望你能够扮演这个剧本杀游戏的主持人,我希望你能够逐步引导玩家到达最终结局,同时希望你在游戏中设定一些随机事件,需要玩家依靠自身的能力解决,当玩家做出偏离主线的行为或者与剧本无关的行为时,你需要委婉地将玩家引导至正常游玩路线中,对于玩家需要决策的事件,你需要提供一些行动推荐,下面是剧本介绍:{script_data}"
                if query == 'create':
                    name = input('请输入剧本名称:')
                    characters = input('请输入角色信息:')
                    summary = input('请输入剧情介绍:')
                    details = input('请输入剧本细节')
                    create_script(name, characters, summary, details)
                    print('剧本创建成功!')
                    continue
                if query in ("exit", "stop"):
                    break
                # "\r" redraws the line with each response as it arrives.
                for response, _ in self.streaming_output(query, history):
                    print("\r" + response, end="")
                print("\n")
        finally:
            print("\nThank you for using the SparkDesk AI. Welcome to use it again!")
from langchain.llms.base import LLM
from typing import Any, List, Mapping, Optional
class Spark_forlangchain(LLM):
    """LangChain ``LLM`` subclass that answers prompts through ``SparkAPI.chat``."""

    # Integer member of the class; reported by ``_identifying_params``.
    n: int
    # Credentials handed to ``SparkAPI`` on every call.
    app_id: str
    api_key: str
    api_secret: str

    # Identifies the type of this subclass.
    @property
    def _llm_type(self) -> str:
        """Return the type tag of this LLM."""
        return "Spark"

    # Overrides the base-class method: answers the user's prompt with a string.
    def _call(
        self,
        query: str,
        history: Optional[list] = None,  # store the conversation history
        user_id: str = "001",
        domain: str = "general",
        max_tokens: int = 4096,
        temperature: float = 0.7,
        stop: Optional[List[str]] = None,
    ) -> str:
        """Send ``query`` to Spark and return the answer text.

        Raises:
            ValueError: if ``stop`` is given; stop sequences are not supported.
        """
        if stop is not None:
            raise ValueError("stop kwargs are not permitted.")
        bot = SparkAPI(app_id=self.app_id, api_key=self.api_key,
                       api_secret=self.api_secret)
        response = bot.chat(query, history, user_id,
                            domain, max_tokens, temperature)
        # SparkAPI.chat returns None after a closed connection; keep the
        # declared ``str`` return type by mapping that case to "".
        return "" if response is None else response

    # Returns a mapping holding the identifying parameters of this LLM.
    @property
    def _identifying_params(self) -> Mapping[str, Any]:
        """Get the identifying parameters."""
        return {"n": self.n}