JobSmithManipulation Kevin Hu committed on
Commit
fa680e0
·
1 Parent(s): 6d81859

Support api-version and change the default model when adding Azure-OpenAI and OpenAI (#2799)

Browse files

### What problem does this PR solve?
#2701 #2712 #2749

### Type of change
- [x] Bug Fix (non-breaking change which fixes an issue)
- [x] New Feature (non-breaking change which adds functionality)

---------

Co-authored-by: Kevin Hu <[email protected]>

api/apps/llm_app.py CHANGED
@@ -58,7 +58,7 @@ def set_api_key():
58
  chat_passed, embd_passed, rerank_passed = False, False, False
59
  factory = req["llm_factory"]
60
  msg = ""
61
- for llm in LLMService.query(fid=factory)[:3]:
62
  if not embd_passed and llm.model_type == LLMType.EMBEDDING.value:
63
  mdl = EmbeddingModel[factory](
64
  req["api_key"], llm.llm_name, base_url=req.get("base_url"))
@@ -77,10 +77,10 @@ def set_api_key():
77
  {"temperature": 0.9,'max_tokens':50})
78
  if m.find("**ERROR**") >=0:
79
  raise Exception(m)
 
80
  except Exception as e:
81
  msg += f"\nFail to access model({llm.llm_name}) using this api key." + str(
82
  e)
83
- chat_passed = True
84
  elif not rerank_passed and llm.model_type == LLMType.RERANK:
85
  mdl = RerankModel[factory](
86
  req["api_key"], llm.llm_name, base_url=req.get("base_url"))
@@ -88,10 +88,14 @@ def set_api_key():
88
  arr, tc = mdl.similarity("What's the weather?", ["Is it sunny today?"])
89
  if len(arr) == 0 or tc == 0:
90
  raise Exception("Fail")
 
 
91
  except Exception as e:
92
  msg += f"\nFail to access model({llm.llm_name}) using this api key." + str(
93
  e)
94
- rerank_passed = True
 
 
95
 
96
  if msg:
97
  return get_data_error_result(retmsg=msg)
@@ -183,6 +187,10 @@ def add_llm():
183
  llm_name = req["llm_name"]
184
  api_key = apikey_json(["google_project_id", "google_region", "google_service_account_key"])
185
 
 
 
 
 
186
  else:
187
  llm_name = req["llm_name"]
188
  api_key = req.get("api_key", "xxxxxxxxxxxxxxx")
 
58
  chat_passed, embd_passed, rerank_passed = False, False, False
59
  factory = req["llm_factory"]
60
  msg = ""
61
+ for llm in LLMService.query(fid=factory):
62
  if not embd_passed and llm.model_type == LLMType.EMBEDDING.value:
63
  mdl = EmbeddingModel[factory](
64
  req["api_key"], llm.llm_name, base_url=req.get("base_url"))
 
77
  {"temperature": 0.9,'max_tokens':50})
78
  if m.find("**ERROR**") >=0:
79
  raise Exception(m)
80
+ chat_passed = True
81
  except Exception as e:
82
  msg += f"\nFail to access model({llm.llm_name}) using this api key." + str(
83
  e)
 
84
  elif not rerank_passed and llm.model_type == LLMType.RERANK:
85
  mdl = RerankModel[factory](
86
  req["api_key"], llm.llm_name, base_url=req.get("base_url"))
 
88
  arr, tc = mdl.similarity("What's the weather?", ["Is it sunny today?"])
89
  if len(arr) == 0 or tc == 0:
90
  raise Exception("Fail")
91
+ rerank_passed = True
92
+ print(f'passed model rerank{llm.llm_name}',flush=True)
93
  except Exception as e:
94
  msg += f"\nFail to access model({llm.llm_name}) using this api key." + str(
95
  e)
96
+ if any([embd_passed, chat_passed, rerank_passed]):
97
+ msg = ''
98
+ break
99
 
100
  if msg:
101
  return get_data_error_result(retmsg=msg)
 
187
  llm_name = req["llm_name"]
188
  api_key = apikey_json(["google_project_id", "google_region", "google_service_account_key"])
189
 
190
+ elif factory == "Azure-OpenAI":
191
+ llm_name = req["llm_name"]
192
+ api_key = apikey_json(["api_key", "api_version"])
193
+
194
  else:
195
  llm_name = req["llm_name"]
196
  api_key = req.get("api_key", "xxxxxxxxxxxxxxx")
conf/llm_factories.json CHANGED
@@ -619,13 +619,13 @@
619
  "model_type": "chat,image2text"
620
  },
621
  {
622
- "llm_name": "gpt-35-turbo",
623
  "tags": "LLM,CHAT,4K",
624
  "max_tokens": 4096,
625
  "model_type": "chat"
626
  },
627
  {
628
- "llm_name": "gpt-35-turbo-16k",
629
  "tags": "LLM,CHAT,16k",
630
  "max_tokens": 16385,
631
  "model_type": "chat"
 
619
  "model_type": "chat,image2text"
620
  },
621
  {
622
+ "llm_name": "gpt-3.5-turbo",
623
  "tags": "LLM,CHAT,4K",
624
  "max_tokens": 4096,
625
  "model_type": "chat"
626
  },
627
  {
628
+ "llm_name": "gpt-3.5-turbo-16k",
629
  "tags": "LLM,CHAT,16k",
630
  "max_tokens": 16385,
631
  "model_type": "chat"
rag/llm/chat_model.py CHANGED
@@ -114,7 +114,9 @@ class DeepSeekChat(Base):
114
 
115
  class AzureChat(Base):
116
  def __init__(self, key, model_name, **kwargs):
117
- self.client = AzureOpenAI(api_key=key, azure_endpoint=kwargs["base_url"], api_version="2024-02-01")
 
 
118
  self.model_name = model_name
119
 
120
 
 
114
 
115
  class AzureChat(Base):
116
  def __init__(self, key, model_name, **kwargs):
117
+ api_key = json.loads(key).get('api_key', '')
118
+ api_version = json.loads(key).get('api_version', '2024-02-01')
119
+ self.client = AzureOpenAI(api_key=api_key, azure_endpoint=kwargs["base_url"], api_version=api_version)
120
  self.model_name = model_name
121
 
122
 
rag/llm/cv_model.py CHANGED
@@ -160,7 +160,9 @@ class GptV4(Base):
160
 
161
  class AzureGptV4(Base):
162
  def __init__(self, key, model_name, lang="Chinese", **kwargs):
163
- self.client = AzureOpenAI(api_key=key, azure_endpoint=kwargs["base_url"], api_version="2024-02-01")
 
 
164
  self.model_name = model_name
165
  self.lang = lang
166
 
 
160
 
161
  class AzureGptV4(Base):
162
  def __init__(self, key, model_name, lang="Chinese", **kwargs):
163
+ api_key = json.loads(key).get('api_key', '')
164
+ api_version = json.loads(key).get('api_version', '2024-02-01')
165
+ self.client = AzureOpenAI(api_key=api_key, azure_endpoint=kwargs["base_url"], api_version=api_version)
166
  self.model_name = model_name
167
  self.lang = lang
168
 
rag/llm/embedding_model.py CHANGED
@@ -137,7 +137,9 @@ class LocalAIEmbed(Base):
137
  class AzureEmbed(OpenAIEmbed):
138
  def __init__(self, key, model_name, **kwargs):
139
  from openai.lib.azure import AzureOpenAI
140
- self.client = AzureOpenAI(api_key=key, azure_endpoint=kwargs["base_url"], api_version="2024-02-01")
 
 
141
  self.model_name = model_name
142
 
143
 
 
137
  class AzureEmbed(OpenAIEmbed):
138
  def __init__(self, key, model_name, **kwargs):
139
  from openai.lib.azure import AzureOpenAI
140
+ api_key = json.loads(key).get('api_key', '')
141
+ api_version = json.loads(key).get('api_version', '2024-02-01')
142
+ self.client = AzureOpenAI(api_key=api_key, azure_endpoint=kwargs["base_url"], api_version=api_version)
143
  self.model_name = model_name
144
 
145
 
web/src/locales/en.ts CHANGED
@@ -581,6 +581,8 @@ The above is the content you need to summarize.`,
581
  GoogleRegionMessage: 'Please input Google Cloud Region',
582
  modelProvidersWarn:
583
  'Please add both embedding model and LLM in <b>Settings > Model providers</b> firstly.',
 
 
584
  },
585
  message: {
586
  registered: 'Registered!',
 
581
  GoogleRegionMessage: 'Please input Google Cloud Region',
582
  modelProvidersWarn:
583
  'Please add both embedding model and LLM in <b>Settings > Model providers</b> firstly.',
584
+ apiVersion: 'API-Version',
585
+ apiVersionMessage: 'Please input API version',
586
  },
587
  message: {
588
  registered: 'Registered!',
web/src/locales/zh.ts CHANGED
@@ -557,6 +557,8 @@ export default {
557
  GoogleRegionMessage: '请输入 Google Cloud 区域',
558
  modelProvidersWarn:
559
  '请首先在 <b>设置 > 模型提供商</b> 中添加嵌入模型和 LLM。',
 
 
560
  },
561
  message: {
562
  registered: '注册成功',
 
557
  GoogleRegionMessage: '请输入 Google Cloud 区域',
558
  modelProvidersWarn:
559
  '请首先在 <b>设置 > 模型提供商</b> 中添加嵌入模型和 LLM。',
560
+ apiVersion: 'API版本',
561
+ apiVersionMessage: '请输入API版本!',
562
  },
563
  message: {
564
  registered: '注册成功',
web/src/pages/user-setting/setting-model/azure-openai-modal/index.tsx ADDED
@@ -0,0 +1,128 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import { useTranslate } from '@/hooks/common-hooks';
2
+ import { IModalProps } from '@/interfaces/common';
3
+ import { IAddLlmRequestBody } from '@/interfaces/request/llm';
4
+ import { Form, Input, Modal, Select, Switch } from 'antd';
5
+ import omit from 'lodash/omit';
6
+
7
+ type FieldType = IAddLlmRequestBody & {
8
+ api_version: string;
9
+ vision: boolean;
10
+ };
11
+
12
+ const { Option } = Select;
13
+
14
+ const AzureOpenAIModal = ({
15
+ visible,
16
+ hideModal,
17
+ onOk,
18
+ loading,
19
+ llmFactory,
20
+ }: IModalProps<IAddLlmRequestBody> & { llmFactory: string }) => {
21
+ const [form] = Form.useForm<FieldType>();
22
+
23
+ const { t } = useTranslate('setting');
24
+
25
+ const handleOk = async () => {
26
+ const values = await form.validateFields();
27
+ const modelType =
28
+ values.model_type === 'chat' && values.vision
29
+ ? 'image2text'
30
+ : values.model_type;
31
+
32
+ const data = {
33
+ ...omit(values, ['vision']),
34
+ model_type: modelType,
35
+ llm_factory: llmFactory,
36
+ };
37
+ console.info(data);
38
+
39
+ onOk?.(data);
40
+ };
41
+ const optionsMap = {
42
+ Default: [
43
+ { value: 'chat', label: 'chat' },
44
+ { value: 'embedding', label: 'embedding' },
45
+ { value: 'image2text', label: 'image2text' },
46
+ ],
47
+ };
48
+ const getOptions = (factory: string) => {
49
+ return optionsMap.Default;
50
+ };
51
+ return (
52
+ <Modal
53
+ title={t('addLlmTitle', { name: llmFactory })}
54
+ open={visible}
55
+ onOk={handleOk}
56
+ onCancel={hideModal}
57
+ okButtonProps={{ loading }}
58
+ >
59
+ <Form
60
+ name="basic"
61
+ style={{ maxWidth: 600 }}
62
+ autoComplete="off"
63
+ layout={'vertical'}
64
+ form={form}
65
+ >
66
+ <Form.Item<FieldType>
67
+ label={t('modelType')}
68
+ name="model_type"
69
+ initialValue={'embedding'}
70
+ rules={[{ required: true, message: t('modelTypeMessage') }]}
71
+ >
72
+ <Select placeholder={t('modelTypeMessage')}>
73
+ {getOptions(llmFactory).map((option) => (
74
+ <Option key={option.value} value={option.value}>
75
+ {option.label}
76
+ </Option>
77
+ ))}
78
+ </Select>
79
+ </Form.Item>
80
+ <Form.Item<FieldType>
81
+ label={t('addLlmBaseUrl')}
82
+ name="api_base"
83
+ rules={[{ required: true, message: t('baseUrlNameMessage') }]}
84
+ >
85
+ <Input placeholder={t('baseUrlNameMessage')} />
86
+ </Form.Item>
87
+ <Form.Item<FieldType>
88
+ label={t('apiKey')}
89
+ name="api_key"
90
+ rules={[{ required: false, message: t('apiKeyMessage') }]}
91
+ >
92
+ <Input placeholder={t('apiKeyMessage')} />
93
+ </Form.Item>
94
+ <Form.Item<FieldType>
95
+ label={t('modelName')}
96
+ name="llm_name"
97
+ initialValue="gpt-3.5-turbo"
98
+ rules={[{ required: true, message: t('modelNameMessage') }]}
99
+ >
100
+ <Input placeholder={t('modelNameMessage')} />
101
+ </Form.Item>
102
+ <Form.Item<FieldType>
103
+ label={t('apiVersion')}
104
+ name="api_version"
105
+ initialValue="2024-02-01"
106
+ rules={[{ required: false, message: t('apiVersionMessage') }]}
107
+ >
108
+ <Input placeholder={t('apiVersionMessage')} />
109
+ </Form.Item>
110
+ <Form.Item noStyle dependencies={['model_type']}>
111
+ {({ getFieldValue }) =>
112
+ getFieldValue('model_type') === 'chat' && (
113
+ <Form.Item
114
+ label={t('vision')}
115
+ valuePropName="checked"
116
+ name={'vision'}
117
+ >
118
+ <Switch />
119
+ </Form.Item>
120
+ )
121
+ }
122
+ </Form.Item>
123
+ </Form>
124
+ </Modal>
125
+ );
126
+ };
127
+
128
+ export default AzureOpenAIModal;
web/src/pages/user-setting/setting-model/hooks.ts CHANGED
@@ -353,6 +353,33 @@ export const useSubmitBedrock = () => {
353
  };
354
  };
355
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
356
  export const useHandleDeleteLlm = (llmFactory: string) => {
357
  const { deleteLlm } = useDeleteLlm();
358
  const showDeleteConfirm = useShowDeleteConfirm();
 
353
  };
354
  };
355
 
356
+ export const useSubmitAzure = () => {
357
+ const { addLlm, loading } = useAddLlm();
358
+ const {
359
+ visible: AzureAddingVisible,
360
+ hideModal: hideAzureAddingModal,
361
+ showModal: showAzureAddingModal,
362
+ } = useSetModalState();
363
+
364
+ const onAzureAddingOk = useCallback(
365
+ async (payload: IAddLlmRequestBody) => {
366
+ const ret = await addLlm(payload);
367
+ if (ret === 0) {
368
+ hideAzureAddingModal();
369
+ }
370
+ },
371
+ [hideAzureAddingModal, addLlm],
372
+ );
373
+
374
+ return {
375
+ AzureAddingLoading: loading,
376
+ onAzureAddingOk,
377
+ AzureAddingVisible,
378
+ hideAzureAddingModal,
379
+ showAzureAddingModal,
380
+ };
381
+ };
382
+
383
  export const useHandleDeleteLlm = (llmFactory: string) => {
384
  const { deleteLlm } = useDeleteLlm();
385
  const showDeleteConfirm = useShowDeleteConfirm();
web/src/pages/user-setting/setting-model/index.tsx CHANGED
@@ -29,6 +29,7 @@ import SettingTitle from '../components/setting-title';
29
  import { isLocalLlmFactory } from '../utils';
30
  import TencentCloudModal from './Tencent-modal';
31
  import ApiKeyModal from './api-key-modal';
 
32
  import BedrockModal from './bedrock-modal';
33
  import { IconMap } from './constant';
34
  import FishAudioModal from './fish-audio-modal';
@@ -37,6 +38,7 @@ import {
37
  useHandleDeleteFactory,
38
  useHandleDeleteLlm,
39
  useSubmitApiKey,
 
40
  useSubmitBedrock,
41
  useSubmitFishAudio,
42
  useSubmitGoogle,
@@ -109,7 +111,8 @@ const ModelCard = ({ item, clickApiKey }: IModelCardProps) => {
109
  item.name === 'BaiduYiyan' ||
110
  item.name === 'Fish Audio' ||
111
  item.name === 'Tencent Cloud' ||
112
- item.name === 'Google Cloud'
 
113
  ? t('addTheModel')
114
  : 'API-Key'}
115
  <SettingOutlined />
@@ -242,6 +245,14 @@ const UserSettingModel = () => {
242
  showBedrockAddingModal,
243
  } = useSubmitBedrock();
244
 
 
 
 
 
 
 
 
 
245
  const ModalMap = useMemo(
246
  () => ({
247
  Bedrock: showBedrockAddingModal,
@@ -252,6 +263,7 @@ const UserSettingModel = () => {
252
  'Fish Audio': showFishAudioAddingModal,
253
  'Tencent Cloud': showTencentCloudAddingModal,
254
  'Google Cloud': showGoogleAddingModal,
 
255
  }),
256
  [
257
  showBedrockAddingModal,
@@ -262,6 +274,7 @@ const UserSettingModel = () => {
262
  showyiyanAddingModal,
263
  showFishAudioAddingModal,
264
  showGoogleAddingModal,
 
265
  ],
266
  );
267
 
@@ -435,6 +448,13 @@ const UserSettingModel = () => {
435
  loading={bedrockAddingLoading}
436
  llmFactory={'Bedrock'}
437
  ></BedrockModal>
 
 
 
 
 
 
 
438
  </section>
439
  );
440
  };
 
29
  import { isLocalLlmFactory } from '../utils';
30
  import TencentCloudModal from './Tencent-modal';
31
  import ApiKeyModal from './api-key-modal';
32
+ import AzureOpenAIModal from './azure-openai-modal';
33
  import BedrockModal from './bedrock-modal';
34
  import { IconMap } from './constant';
35
  import FishAudioModal from './fish-audio-modal';
 
38
  useHandleDeleteFactory,
39
  useHandleDeleteLlm,
40
  useSubmitApiKey,
41
+ useSubmitAzure,
42
  useSubmitBedrock,
43
  useSubmitFishAudio,
44
  useSubmitGoogle,
 
111
  item.name === 'BaiduYiyan' ||
112
  item.name === 'Fish Audio' ||
113
  item.name === 'Tencent Cloud' ||
114
+ item.name === 'Google Cloud' ||
115
+ item.name === 'Azure OpenAI'
116
  ? t('addTheModel')
117
  : 'API-Key'}
118
  <SettingOutlined />
 
245
  showBedrockAddingModal,
246
  } = useSubmitBedrock();
247
 
248
+ const {
249
+ AzureAddingVisible,
250
+ hideAzureAddingModal,
251
+ showAzureAddingModal,
252
+ onAzureAddingOk,
253
+ AzureAddingLoading,
254
+ } = useSubmitAzure();
255
+
256
  const ModalMap = useMemo(
257
  () => ({
258
  Bedrock: showBedrockAddingModal,
 
263
  'Fish Audio': showFishAudioAddingModal,
264
  'Tencent Cloud': showTencentCloudAddingModal,
265
  'Google Cloud': showGoogleAddingModal,
266
+ 'Azure-OpenAI': showAzureAddingModal,
267
  }),
268
  [
269
  showBedrockAddingModal,
 
274
  showyiyanAddingModal,
275
  showFishAudioAddingModal,
276
  showGoogleAddingModal,
277
+ showAzureAddingModal,
278
  ],
279
  );
280
 
 
448
  loading={bedrockAddingLoading}
449
  llmFactory={'Bedrock'}
450
  ></BedrockModal>
451
+ <AzureOpenAIModal
452
+ visible={AzureAddingVisible}
453
+ hideModal={hideAzureAddingModal}
454
+ onOk={onAzureAddingOk}
455
+ loading={AzureAddingLoading}
456
+ llmFactory={'Azure-OpenAI'}
457
+ ></AzureOpenAIModal>
458
  </section>
459
  );
460
  };
web/src/pages/user-setting/setting-model/ollama-modal/index.tsx CHANGED
@@ -101,7 +101,7 @@ const OllamaModal = ({
101
  <Form.Item<FieldType>
102
  label={t('modelType')}
103
  name="model_type"
104
- initialValue={'chat'}
105
  rules={[{ required: true, message: t('modelTypeMessage') }]}
106
  >
107
  <Select placeholder={t('modelTypeMessage')}>
 
101
  <Form.Item<FieldType>
102
  label={t('modelType')}
103
  name="model_type"
104
+ initialValue={'embedding'}
105
  rules={[{ required: true, message: t('modelTypeMessage') }]}
106
  >
107
  <Select placeholder={t('modelTypeMessage')}>