|
from __future__ import annotations |
|
|
|
import json |
|
import logging |
|
import traceback |
|
|
|
import colorama |
|
import requests |
|
|
|
from .. import shared |
|
from ..config import retrieve_proxy, sensitive_id, usage_limit |
|
from ..index_func import * |
|
from ..presets import * |
|
from ..utils import * |
|
from .base_model import BaseLLMModel |
|
|
|
|
|
class OpenAIClient(BaseLLMModel):
    """Chat model client that talks to a chat-completions HTTP endpoint.

    Requests are sent with ``requests`` to ``shared.state.chat_completion_url``
    using ``self.api_key`` as the bearer token. Billing queries use a separate
    header built from ``sensitive_id`` (see ``_refresh_header``).
    """

    def __init__(
        self,
        model_name,
        api_key,
        system_prompt=INITIAL_SYSTEM_PROMPT,
        temperature=1.0,
        top_p=1.0,
        user_name="",
    ) -> None:
        """Initialise the client.

        Args:
            model_name: Key into ``MODEL_METADATA``; the entry's
                ``"model_name"`` value is what is handed to the base class.
            api_key: Bearer token used for chat-completion requests.
            system_prompt: System prompt forwarded to the base class.
            temperature: Sampling temperature forwarded to the base class.
            top_p: Nucleus-sampling value forwarded to the base class.
            user_name: Forwarded to the base class as ``user``.
        """
        super().__init__(
            model_name=MODEL_METADATA[model_name]["model_name"],
            temperature=temperature,
            top_p=top_p,
            system_prompt=system_prompt,
            user=user_name,
        )
        self.api_key = api_key
        self.need_api_key = True
        self._refresh_header()

    def get_answer_stream_iter(self):
        """Yield the accumulated answer text as streamed chunks arrive.

        Each yielded value is the full text received so far, not just the
        newest fragment. If the request could not be sent (``_get_response``
        returned ``None``), a single generic error message is yielded.
        """
        response = self._get_response(stream=True)
        if response is None:
            yield STANDARD_ERROR_MSG + GENERAL_ERROR_MSG
            return

        partial_text = ""
        for fragment in self._decode_chat_response(response):
            partial_text += fragment
            yield partial_text

    def get_answer_at_once(self):
        """Request a non-streamed completion.

        Returns:
            A ``(content, total_token_count)`` tuple read from the first
            choice and the ``usage`` section of the JSON response.
        """
        # NOTE(review): _get_response() returns None when the request itself
        # fails; in that case the attribute access below raises
        # AttributeError. Left as-is because the exception type callers
        # expect is not visible from this file.
        response = self._get_response()
        response = json.loads(response.text)
        content = response["choices"][0]["message"]["content"]
        total_token_count = response["usage"]["total_tokens"]
        return content, total_token_count

    def count_token(self, user_input):
        """Return the token count for ``user_input``.

        The system prompt's tokens are added only while
        ``self.all_token_counts`` is still empty.
        """
        # Inside a method body the bare name resolves to the module-level
        # ``count_token`` helper, not to this method.
        input_token_count = count_token(construct_user(user_input))
        if self.system_prompt is not None and len(self.all_token_counts) == 0:
            system_prompt_token_count = count_token(
                construct_system(self.system_prompt)
            )
            return input_token_count + system_prompt_token_count
        return input_token_count

    def billing_info(self):
        """Return an HTML snippet (or an error string) describing this month's usage.

        The usage endpoint is queried for the current calendar month. Every
        failure is converted into a user-facing message; this method does not
        raise.
        """
        try:
            curr_time = datetime.datetime.now()
            last_day_of_month = get_last_day_of_month(curr_time).strftime("%Y-%m-%d")
            first_day_of_month = curr_time.replace(day=1).strftime("%Y-%m-%d")
            usage_url = f"{shared.state.usage_api_url}?start_date={first_day_of_month}&end_date={last_day_of_month}"
            try:
                usage_data = self._get_billing_data(usage_url)
            except (
                requests.exceptions.ConnectTimeout,
                requests.exceptions.ReadTimeout,
            ):
                # Let timeouts reach the dedicated handlers below; the generic
                # handler that follows would otherwise make them unreachable.
                raise
            except Exception as e:
                error_text = str(e)
                if "Invalid authorization header" in error_text:
                    return i18n("**获取API使用情况失败**,需在填写`config.json`中正确填写sensitive_id")
                elif "Incorrect API key provided: sess" in error_text:
                    return i18n("**获取API使用情况失败**,sensitive_id错误或已过期")
                return i18n("**获取API使用情况失败**")

            # presumably total_usage is reported in hundredths of the currency
            # unit used by usage_limit — TODO confirm against the usage API.
            rounded_usage = round(usage_data["total_usage"] / 100, 5)
            usage_percent = round(usage_data["total_usage"] / usage_limit, 2)

            # Imported here rather than at module level, as in the original.
            from ..webui import get_html

            return get_html("billing_info.html").format(
                label=i18n("本月使用金额"),
                usage_percent=usage_percent,
                rounded_usage=rounded_usage,
                usage_limit=usage_limit,
            )
        except requests.exceptions.ConnectTimeout:
            status_text = (
                STANDARD_ERROR_MSG + CONNECTION_TIMEOUT_MSG + ERROR_RETRIEVE_MSG
            )
            return status_text
        except requests.exceptions.ReadTimeout:
            status_text = STANDARD_ERROR_MSG + READ_TIMEOUT_MSG + ERROR_RETRIEVE_MSG
            return status_text
        except Exception as e:
            traceback.print_exc()
            logging.error(i18n("获取API使用情况失败:") + str(e))
            return STANDARD_ERROR_MSG + ERROR_RETRIEVE_MSG

    def set_token_upper_limit(self, new_upper_limit):
        """Intentionally a no-op for this client."""
        pass

    @shared.state.switching_api_key
    def _get_response(self, stream=False):
        """POST the current conversation to the chat-completion endpoint.

        Args:
            stream: Whether to request a streamed response; also selects
                which timeout constant is used.

        Returns:
            The ``requests`` response object, or ``None`` if sending the
            request raised an exception (the traceback is printed).
        """
        history = self.history
        logging.debug(colorama.Fore.YELLOW + f"{history}" + colorama.Fore.RESET)
        headers = {
            "Content-Type": "application/json",
            "Authorization": f"Bearer {self.api_key}",
        }

        if self.system_prompt is not None:
            history = [construct_system(self.system_prompt), *history]

        payload = {
            "model": self.model_name,
            "messages": history,
            "temperature": self.temperature,
            "top_p": self.top_p,
            "n": self.n_choices,
            "stream": stream,
            "presence_penalty": self.presence_penalty,
            "frequency_penalty": self.frequency_penalty,
        }

        # Optional fields are only sent when they have been configured.
        if self.max_generation_token is not None:
            payload["max_tokens"] = self.max_generation_token
        if self.stop_sequence is not None:
            payload["stop"] = self.stop_sequence
        if self.logit_bias is not None:
            payload["logit_bias"] = self.logit_bias
        if self.user_identifier:
            payload["user"] = self.user_identifier

        timeout = TIMEOUT_STREAMING if stream else TIMEOUT_ALL

        if shared.state.chat_completion_url != CHAT_COMPLETION_URL:
            logging.debug(f"使用自定义API URL: {shared.state.chat_completion_url}")

        with retrieve_proxy():
            try:
                response = requests.post(
                    shared.state.chat_completion_url,
                    headers=headers,
                    json=payload,
                    stream=stream,
                    timeout=timeout,
                )
            except Exception:
                # ``Exception`` rather than a bare except, so that
                # KeyboardInterrupt / SystemExit are not swallowed.
                traceback.print_exc()
                return None
        return response

    def _refresh_header(self):
        """Rebuild ``self.headers`` (used for billing requests) from ``sensitive_id``."""
        self.headers = {
            "Content-Type": "application/json",
            "Authorization": f"Bearer {sensitive_id}",
        }

    def _get_billing_data(self, billing_url):
        """GET ``billing_url`` and return the decoded JSON body.

        Raises:
            Exception: If the response status code is not 200; the message
                contains the status code and the response text.
        """
        with retrieve_proxy():
            response = requests.get(
                billing_url,
                headers=self.headers,
                timeout=TIMEOUT_ALL,
            )

        if response.status_code == 200:
            return response.json()
        raise Exception(
            f"API request failed with status code {response.status_code}: {response.text}"
        )

    def _decode_chat_response(self, response):
        """Yield content fragments from a streamed chat-completion response.

        Each non-empty line has its first six characters dropped (the code
        assumes a ``"data: "`` prefix) and the remainder is parsed as JSON.
        Lines that fail to parse are collected; if anything other than the
        ``data: [DONE]`` sentinel was collected, it is raised once the stream
        is exhausted.

        Raises:
            Exception: With the concatenated unparsable lines as its message.
        """
        error_msg = ""
        for raw_line in response.iter_lines():
            if not raw_line:
                continue
            line = raw_line.decode()
            try:
                chunk = json.loads(line[6:])
            except json.JSONDecodeError:
                print(i18n("JSON解析错误,收到的内容: ") + f"{line}")
                error_msg += line
                continue

            try:
                if len(line) > 6 and "delta" in chunk["choices"][0]:
                    choice = chunk["choices"][0]
                    if "finish_reason" in choice:
                        finish_reason = choice["finish_reason"]
                    else:
                        finish_reason = chunk["finish_reason"]
                    if finish_reason == "stop":
                        break
                    try:
                        content = choice["delta"]["content"]
                    except Exception:
                        # Deltas without a "content" entry carry no text.
                        continue
                    # A null content cannot be concatenated by the consumer,
                    # so it is skipped instead of being yielded.
                    if content is not None:
                        yield content
            except Exception:
                # Must not be a bare except: GeneratorExit raised at the
                # ``yield`` when the consumer closes this generator has to
                # propagate, otherwise the loop would continue and yield again.
                print(f"ERROR: {chunk}")
                continue

        if error_msg and error_msg != "data: [DONE]":
            raise Exception(error_msg)

    def set_key(self, new_access_key):
        """Delegate to the base ``set_key`` and then rebuild the billing header."""
        ret = super().set_key(new_access_key)
        self._refresh_header()
        return ret

    def _single_query_at_once(self, history, temperature=1.0):
        """Send a one-off, non-streamed completion request for ``history``.

        Args:
            history: List of message dicts sent as ``messages``.
            temperature: Sampling temperature for this request.

        Returns:
            The raw ``requests`` response object.
        """
        headers = {
            "Content-Type": "application/json",
            "Authorization": f"Bearer {self.api_key}",
        }
        payload = {
            "model": self.model_name,
            "messages": history,
            # Sampling parameters belong in the request body; sent as an
            # HTTP header the value had no effect on the completion.
            "temperature": temperature,
        }

        if shared.state.chat_completion_url != CHAT_COMPLETION_URL:
            logging.debug(f"使用自定义API URL: {shared.state.chat_completion_url}")

        with retrieve_proxy():
            response = requests.post(
                shared.state.chat_completion_url,
                headers=headers,
                json=payload,
                stream=False,
                timeout=TIMEOUT_ALL,
            )

        return response

    def auto_name_chat_history(self, name_chat_method, user_question, chatbot, user_name, single_turn_checkbox):
        """Rename the saved chat after the first exchange, if enabled.

        Renaming only happens when the history holds exactly two messages,
        single-turn mode is off and ``hide_history_when_not_logged_in`` is
        falsy. The file name comes either from a model-written title or from
        the first 16 characters of the (sanitised) first question. In every
        other case ``gr.update()`` is returned.

        Note: the ``user_question`` argument is overwritten with
        ``self.history[0]["content"]`` before use.
        """
        if (
            len(self.history) != 2
            or single_turn_checkbox
            or hide_history_when_not_logged_in
        ):
            return gr.update()

        user_question = self.history[0]["content"]

        if name_chat_method == i18n("模型自动总结(消耗tokens)"):
            ai_answer = self.history[1]["content"]
            try:
                history = [
                    {"role": "system", "content": SUMMARY_CHAT_SYSTEM_PROMPT},
                    {"role": "user", "content": f"Please write a title based on the following conversation:\n---\nUser: {user_question}\nAssistant: {ai_answer}"},
                ]
                response = self._single_query_at_once(history, temperature=0.0)
                response = json.loads(response.text)
                content = response["choices"][0]["message"]["content"]
                filename = replace_special_symbols(content) + ".json"
            except Exception as e:
                # Fall back to naming the chat after the first question.
                logging.info(f"自动命名失败。{e}")
                filename = replace_special_symbols(user_question)[:16] + ".json"
            return self.rename_chat_history(filename, chatbot, user_name)

        if name_chat_method == i18n("第一条提问"):
            filename = replace_special_symbols(user_question)[:16] + ".json"
            return self.rename_chat_history(filename, chatbot, user_name)

        return gr.update()
|
|