Xudong Liu commited on
Commit
d987918
·
unverified ·
1 Parent(s): c51b92e

增加对Anthropic的Claude大模型的支持 (#919)

Browse files

* 增加了对Anthropic的Claude模型的支持

* 增加对Anthropic的Claude大模型的支持

config_example.json CHANGED
@@ -14,6 +14,7 @@
14
  "spark_appid": "", // 你的 讯飞星火大模型 API AppID,用于讯飞星火大模型对话模型
15
  "spark_api_key": "", // 你的 讯飞星火大模型 API Key,用于讯飞星火大模型对话模型
16
  "spark_api_secret": "", // 你的 讯飞星火大模型 API Secret,用于讯飞星火大模型对话模型
 
17
 
18
 
19
  //== Azure ==
 
14
  "spark_appid": "", // 你的 讯飞星火大模型 API AppID,用于讯飞星火大模型对话模型
15
  "spark_api_key": "", // 你的 讯飞星火大模型 API Key,用于讯飞星火大模型对话模型
16
  "spark_api_secret": "", // 你的 讯飞星火大模型 API Secret,用于讯飞星火大模型对话模型
17
+ "claude_api_secret":"",// 你的 Claude API Secret,用于 Claude 对话模型
18
 
19
 
20
  //== Azure ==
modules/config.py CHANGED
@@ -128,6 +128,9 @@ os.environ["SPARK_APPID"] = spark_appid
128
  spark_api_secret = config.get("spark_api_secret", "")
129
  os.environ["SPARK_API_SECRET"] = spark_api_secret
130
 
 
 
 
131
  load_config_to_environ(["openai_api_type", "azure_openai_api_key", "azure_openai_api_base_url",
132
  "azure_openai_api_version", "azure_deployment_name", "azure_embedding_deployment_name", "azure_embedding_model_name"])
133
 
 
128
  spark_api_secret = config.get("spark_api_secret", "")
129
  os.environ["SPARK_API_SECRET"] = spark_api_secret
130
 
131
+ claude_api_secret = config.get("claude_api_secret", "")
132
+ os.environ["CLAUDE_API_SECRET"] = claude_api_secret
133
+
134
  load_config_to_environ(["openai_api_type", "azure_openai_api_key", "azure_openai_api_base_url",
135
  "azure_openai_api_version", "azure_deployment_name", "azure_embedding_deployment_name", "azure_embedding_model_name"])
136
 
modules/models/Claude.py ADDED
@@ -0,0 +1,55 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ from anthropic import Anthropic, HUMAN_PROMPT, AI_PROMPT
3
+ from ..presets import *
4
+ from ..utils import *
5
+
6
+ from .base_model import BaseLLMModel
7
+
8
+
9
+ class Claude_Client(BaseLLMModel):
10
+ def __init__(self, model_name, api_secret) -> None:
11
+ super().__init__(model_name=model_name)
12
+ self.api_secret = api_secret
13
+ if None in [self.api_secret]:
14
+ raise Exception("请在配置文件或者环境变量中设置Claude的API Secret")
15
+ self.claude_client = Anthropic(api_key=self.api_secret)
16
+
17
+
18
+ def get_answer_stream_iter(self):
19
+ system_prompt = self.system_prompt
20
+ history = self.history
21
+ if system_prompt is not None:
22
+ history = [construct_system(system_prompt), *history]
23
+
24
+ completion = self.claude_client.completions.create(
25
+ model=self.model_name,
26
+ max_tokens_to_sample=300,
27
+ prompt=f"{HUMAN_PROMPT}{history}{AI_PROMPT}",
28
+ stream=True,
29
+ )
30
+ if completion is not None:
31
+ partial_text = ""
32
+ for chunk in completion:
33
+ partial_text += chunk.completion
34
+ yield partial_text
35
+ else:
36
+ yield STANDARD_ERROR_MSG + GENERAL_ERROR_MSG
37
+
38
+
39
+ def get_answer_at_once(self):
40
+ system_prompt = self.system_prompt
41
+ history = self.history
42
+ if system_prompt is not None:
43
+ history = [construct_system(system_prompt), *history]
44
+
45
+ completion = self.claude_client.completions.create(
46
+ model=self.model_name,
47
+ max_tokens_to_sample=300,
48
+ prompt=f"{HUMAN_PROMPT}{history}{AI_PROMPT}",
49
+ )
50
+ if completion is not None:
51
+ return completion.completion, len(completion.completion)
52
+ else:
53
+ return "获取资源错误", 0
54
+
55
+
modules/models/base_model.py CHANGED
@@ -145,6 +145,7 @@ class ModelType(Enum):
145
  Midjourney = 11
146
  Spark = 12
147
  OpenAIInstruct = 13
 
148
 
149
  @classmethod
150
  def get_type(cls, model_name: str):
@@ -179,6 +180,8 @@ class ModelType(Enum):
179
  model_type = ModelType.LangchainChat
180
  elif "星火大模型" in model_name_lower:
181
  model_type = ModelType.Spark
 
 
182
  else:
183
  model_type = ModelType.LLaMA
184
  return model_type
 
145
  Midjourney = 11
146
  Spark = 12
147
  OpenAIInstruct = 13
148
+ Claude = 14
149
 
150
  @classmethod
151
  def get_type(cls, model_name: str):
 
180
  model_type = ModelType.LangchainChat
181
  elif "星火大模型" in model_name_lower:
182
  model_type = ModelType.Spark
183
+ elif "claude" in model_name_lower:
184
+ model_type = ModelType.Claude
185
  else:
186
  model_type = ModelType.LLaMA
187
  return model_type
modules/models/models.py CHANGED
@@ -116,6 +116,9 @@ def get_model(
116
  from .spark import Spark_Client
117
  model = Spark_Client(model_name, os.getenv("SPARK_APPID"), os.getenv(
118
  "SPARK_API_KEY"), os.getenv("SPARK_API_SECRET"), user_name=user_name)
 
 
 
119
  elif model_type == ModelType.Unknown:
120
  raise ValueError(f"未知模型: {model_name}")
121
  logging.info(msg)
 
116
  from .spark import Spark_Client
117
  model = Spark_Client(model_name, os.getenv("SPARK_APPID"), os.getenv(
118
  "SPARK_API_KEY"), os.getenv("SPARK_API_SECRET"), user_name=user_name)
119
+ elif model_type == ModelType.Claude:
120
+ from .Claude import Claude_Client
121
+ model = Claude_Client(model_name="claude-2", api_secret=os.getenv("CLAUDE_API_SECRET"))
122
  elif model_type == ModelType.Unknown:
123
  raise ValueError(f"未知模型: {model_name}")
124
  logging.info(msg)
modules/presets.py CHANGED
@@ -74,7 +74,8 @@ ONLINE_MODELS = [
74
  "minimax-abab5-chat",
75
  "midjourney",
76
  "讯飞星火大模型V2.0",
77
- "讯飞星火大模型V1.5"
 
78
  ]
79
 
80
  LOCAL_MODELS = [
@@ -125,7 +126,8 @@ MODEL_TOKEN_LIMIT = {
125
  "gpt-4-0613": 8192,
126
  "gpt-4-32k": 32768,
127
  "gpt-4-32k-0314": 32768,
128
- "gpt-4-32k-0613": 32768
 
129
  }
130
 
131
  TOKEN_OFFSET = 1000 # 模型的token上限减去这个值,得到软上限。到达软上限之后,自动尝试减少token占用。
 
74
  "minimax-abab5-chat",
75
  "midjourney",
76
  "讯飞星火大模型V2.0",
77
+ "讯飞星火大模型V1.5",
78
+ "Claude"
79
  ]
80
 
81
  LOCAL_MODELS = [
 
126
  "gpt-4-0613": 8192,
127
  "gpt-4-32k": 32768,
128
  "gpt-4-32k-0314": 32768,
129
+ "gpt-4-32k-0613": 32768,
130
+ "Claude": 4096
131
  }
132
 
133
  TOKEN_OFFSET = 1000 # 模型的token上限减去这个值,得到软上限。到达软上限之后,自动尝试减少token占用。
requirements.txt CHANGED
@@ -30,3 +30,5 @@ python-docx
30
  websocket_client
31
  pydantic==1.10.8
32
  google-search-results
 
 
 
30
  websocket_client
31
  pydantic==1.10.8
32
  google-search-results
33
+ anthropic==0.3.11
34
+