JobSmithManipulation Kevin Hu commited on
Commit
44bea96
·
1 Parent(s): b2a5c0f

support sequence2txt and tts model in Xinference (#2696)

Browse files

### Type of change

- [x] New Feature (non-breaking change which adds functionality)

---------

Co-authored-by: Kevin Hu <[email protected]>

api/db/services/llm_service.py CHANGED
@@ -195,7 +195,7 @@ class LLMBundle(object):
195
  self.llm_name = llm_name
196
  self.mdl = TenantLLMService.model_instance(
197
  tenant_id, llm_type, llm_name, lang=lang)
198
- assert self.mdl, "Can't find mole for {}/{}/{}".format(
199
  tenant_id, llm_type, llm_name)
200
  self.max_length = 8192
201
  for lm in LLMService.query(llm_name=llm_name):
 
195
  self.llm_name = llm_name
196
  self.mdl = TenantLLMService.model_instance(
197
  tenant_id, llm_type, llm_name, lang=lang)
198
+ assert self.mdl, "Can't find model for {}/{}/{}".format(
199
  tenant_id, llm_type, llm_name)
200
  self.max_length = 8192
201
  for lm in LLMService.query(llm_name=llm_name):
rag/llm/__init__.py CHANGED
@@ -47,10 +47,9 @@ EmbeddingModel = {
47
  "Replicate": ReplicateEmbed,
48
  "BaiduYiyan": BaiduYiyanEmbed,
49
  "Voyage AI": VoyageEmbed,
50
- "HuggingFace":HuggingFaceEmbed,
51
  }
52
 
53
-
54
  CvModel = {
55
  "OpenAI": GptV4,
56
  "Azure-OpenAI": AzureGptV4,
@@ -64,14 +63,13 @@ CvModel = {
64
  "LocalAI": LocalAICV,
65
  "NVIDIA": NvidiaCV,
66
  "LM-Studio": LmStudioCV,
67
- "StepFun":StepFunCV,
68
  "OpenAI-API-Compatible": OpenAI_APICV,
69
  "TogetherAI": TogetherAICV,
70
  "01.AI": YiCV,
71
  "Tencent Hunyuan": HunyuanCV
72
  }
73
 
74
-
75
  ChatModel = {
76
  "OpenAI": GptTurbo,
77
  "Azure-OpenAI": AzureChat,
@@ -99,7 +97,7 @@ ChatModel = {
99
  "LeptonAI": LeptonAIChat,
100
  "TogetherAI": TogetherAIChat,
101
  "PerfXCloud": PerfXCloudChat,
102
- "Upstage":UpstageChat,
103
  "novita.ai": NovitaAIChat,
104
  "SILICONFLOW": SILICONFLOWChat,
105
  "01.AI": YiChat,
@@ -111,7 +109,6 @@ ChatModel = {
111
  "Google Cloud": GoogleChat,
112
  }
113
 
114
-
115
  RerankModel = {
116
  "BAAI": DefaultRerank,
117
  "Jina": JinaRerank,
@@ -127,11 +124,9 @@ RerankModel = {
127
  "Voyage AI": VoyageRerank
128
  }
129
 
130
-
131
  Seq2txtModel = {
132
  "OpenAI": GPTSeq2txt,
133
  "Tongyi-Qianwen": QWenSeq2txt,
134
- "Ollama": OllamaSeq2txt,
135
  "Azure-OpenAI": AzureSeq2txt,
136
  "Xinference": XinferenceSeq2txt,
137
  "Tencent Cloud": TencentCloudSeq2txt
@@ -140,6 +135,7 @@ Seq2txtModel = {
140
  TTSModel = {
141
  "Fish Audio": FishAudioTTS,
142
  "Tongyi-Qianwen": QwenTTS,
143
- "OpenAI":OpenAITTS,
144
- "XunFei Spark":SparkTTS
145
- }
 
 
47
  "Replicate": ReplicateEmbed,
48
  "BaiduYiyan": BaiduYiyanEmbed,
49
  "Voyage AI": VoyageEmbed,
50
+ "HuggingFace": HuggingFaceEmbed,
51
  }
52
 
 
53
  CvModel = {
54
  "OpenAI": GptV4,
55
  "Azure-OpenAI": AzureGptV4,
 
63
  "LocalAI": LocalAICV,
64
  "NVIDIA": NvidiaCV,
65
  "LM-Studio": LmStudioCV,
66
+ "StepFun": StepFunCV,
67
  "OpenAI-API-Compatible": OpenAI_APICV,
68
  "TogetherAI": TogetherAICV,
69
  "01.AI": YiCV,
70
  "Tencent Hunyuan": HunyuanCV
71
  }
72
 
 
73
  ChatModel = {
74
  "OpenAI": GptTurbo,
75
  "Azure-OpenAI": AzureChat,
 
97
  "LeptonAI": LeptonAIChat,
98
  "TogetherAI": TogetherAIChat,
99
  "PerfXCloud": PerfXCloudChat,
100
+ "Upstage": UpstageChat,
101
  "novita.ai": NovitaAIChat,
102
  "SILICONFLOW": SILICONFLOWChat,
103
  "01.AI": YiChat,
 
109
  "Google Cloud": GoogleChat,
110
  }
111
 
 
112
  RerankModel = {
113
  "BAAI": DefaultRerank,
114
  "Jina": JinaRerank,
 
124
  "Voyage AI": VoyageRerank
125
  }
126
 
 
127
  Seq2txtModel = {
128
  "OpenAI": GPTSeq2txt,
129
  "Tongyi-Qianwen": QWenSeq2txt,
 
130
  "Azure-OpenAI": AzureSeq2txt,
131
  "Xinference": XinferenceSeq2txt,
132
  "Tencent Cloud": TencentCloudSeq2txt
 
135
  TTSModel = {
136
  "Fish Audio": FishAudioTTS,
137
  "Tongyi-Qianwen": QwenTTS,
138
+ "OpenAI": OpenAITTS,
139
+ "XunFei Spark": SparkTTS,
140
+ "Xinference": XinferenceTTS,
141
+ }
rag/llm/sequence2txt_model.py CHANGED
@@ -13,6 +13,7 @@
13
  # See the License for the specific language governing permissions and
14
  # limitations under the License.
15
  #
 
16
  from openai.lib.azure import AzureOpenAI
17
  from zhipuai import ZhipuAI
18
  import io
@@ -25,6 +26,7 @@ from rag.utils import num_tokens_from_string
25
  import base64
26
  import re
27
 
 
28
  class Base(ABC):
29
  def __init__(self, key, model_name):
30
  pass
@@ -36,8 +38,8 @@ class Base(ABC):
36
  response_format="text"
37
  )
38
  return transcription.text.strip(), num_tokens_from_string(transcription.text.strip())
39
-
40
- def audio2base64(self,audio):
41
  if isinstance(audio, bytes):
42
  return base64.b64encode(audio).decode("utf-8")
43
  if isinstance(audio, io.BytesIO):
@@ -77,13 +79,6 @@ class QWenSeq2txt(Base):
77
  return "**ERROR**: " + result.message, 0
78
 
79
 
80
- class OllamaSeq2txt(Base):
81
- def __init__(self, key, model_name, lang="Chinese", **kwargs):
82
- self.client = Client(host=kwargs["base_url"])
83
- self.model_name = model_name
84
- self.lang = lang
85
-
86
-
87
  class AzureSeq2txt(Base):
88
  def __init__(self, key, model_name, lang="Chinese", **kwargs):
89
  self.client = AzureOpenAI(api_key=key, azure_endpoint=kwargs["base_url"], api_version="2024-02-01")
@@ -92,16 +87,53 @@ class AzureSeq2txt(Base):
92
 
93
 
94
  class XinferenceSeq2txt(Base):
95
- def __init__(self, key, model_name="", base_url=""):
96
- if base_url.split("/")[-1] != "v1":
97
- base_url = os.path.join(base_url, "v1")
98
- self.client = OpenAI(api_key="xxx", base_url=base_url)
99
  self.model_name = model_name
100
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
101
 
102
  class TencentCloudSeq2txt(Base):
103
  def __init__(
104
- self, key, model_name="16k_zh", base_url="https://asr.tencentcloudapi.com"
105
  ):
106
  from tencentcloud.common import credential
107
  from tencentcloud.asr.v20190614 import asr_client
 
13
  # See the License for the specific language governing permissions and
14
  # limitations under the License.
15
  #
16
+ import requests
17
  from openai.lib.azure import AzureOpenAI
18
  from zhipuai import ZhipuAI
19
  import io
 
26
  import base64
27
  import re
28
 
29
+
30
  class Base(ABC):
31
  def __init__(self, key, model_name):
32
  pass
 
38
  response_format="text"
39
  )
40
  return transcription.text.strip(), num_tokens_from_string(transcription.text.strip())
41
+
42
+ def audio2base64(self, audio):
43
  if isinstance(audio, bytes):
44
  return base64.b64encode(audio).decode("utf-8")
45
  if isinstance(audio, io.BytesIO):
 
79
  return "**ERROR**: " + result.message, 0
80
 
81
 
 
 
 
 
 
 
 
82
  class AzureSeq2txt(Base):
83
  def __init__(self, key, model_name, lang="Chinese", **kwargs):
84
  self.client = AzureOpenAI(api_key=key, azure_endpoint=kwargs["base_url"], api_version="2024-02-01")
 
87
 
88
 
89
  class XinferenceSeq2txt(Base):
90
+ def __init__(self,key,model_name="whisper-small",**kwargs):
91
+ self.base_url = kwargs.get('base_url', None)
 
 
92
  self.model_name = model_name
93
 
94
+ def transcription(self, audio, language="zh", prompt=None, response_format="json", temperature=0.7):
95
+ if isinstance(audio, str):
96
+ audio_file = open(audio, 'rb')
97
+ audio_data = audio_file.read()
98
+ audio_file_name = audio.split("/")[-1]
99
+ else:
100
+ audio_data = audio
101
+ audio_file_name = "audio.wav"
102
+
103
+ payload = {
104
+ "model": self.model_name,
105
+ "language": language,
106
+ "prompt": prompt,
107
+ "response_format": response_format,
108
+ "temperature": temperature
109
+ }
110
+
111
+ files = {
112
+ "file": (audio_file_name, audio_data, 'audio/wav')
113
+ }
114
+
115
+ try:
116
+ response = requests.post(
117
+ f"{self.base_url}/v1/audio/transcriptions",
118
+ files=files,
119
+ data=payload
120
+ )
121
+ response.raise_for_status()
122
+ result = response.json()
123
+
124
+ if 'text' in result:
125
+ transcription_text = result['text'].strip()
126
+ return transcription_text, num_tokens_from_string(transcription_text)
127
+ else:
128
+ return "**ERROR**: Failed to retrieve transcription.", 0
129
+
130
+ except requests.exceptions.RequestException as e:
131
+ return f"**ERROR**: {str(e)}", 0
132
+
133
 
134
  class TencentCloudSeq2txt(Base):
135
  def __init__(
136
+ self, key, model_name="16k_zh", base_url="https://asr.tencentcloudapi.com"
137
  ):
138
  from tencentcloud.common import credential
139
  from tencentcloud.asr.v20190614 import asr_client
rag/llm/tts_model.py CHANGED
@@ -297,3 +297,36 @@ class SparkTTS:
297
  break
298
  status_code = 1
299
  yield audio_chunk
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
297
  break
298
  status_code = 1
299
  yield audio_chunk
300
+
301
+
302
+
303
+
304
+ class XinferenceTTS:
305
+ def __init__(self, key, model_name, **kwargs):
306
+ self.base_url = kwargs.get("base_url", None)
307
+ self.model_name = model_name
308
+ self.headers = {
309
+ "accept": "application/json",
310
+ "Content-Type": "application/json"
311
+ }
312
+
313
+ def tts(self, text, voice="中文女", stream=True):
314
+ payload = {
315
+ "model": self.model_name,
316
+ "input": text,
317
+ "voice": voice
318
+ }
319
+
320
+ response = requests.post(
321
+ f"{self.base_url}/v1/audio/speech",
322
+ headers=self.headers,
323
+ json=payload,
324
+ stream=stream
325
+ )
326
+
327
+ if response.status_code != 200:
328
+ raise Exception(f"**Error**: {response.status_code}, {response.text}")
329
+
330
+ for chunk in response.iter_content(chunk_size=1024):
331
+ if chunk:
332
+ yield chunk
web/src/pages/user-setting/setting-model/ollama-modal/index.tsx CHANGED
@@ -53,6 +53,26 @@ const OllamaModal = ({
53
  const url =
54
  llmFactoryToUrlMap[llmFactory as LlmFactory] ||
55
  'https://github.com/infiniflow/ragflow/blob/main/docs/guides/deploy_local_llm.mdx';
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
56
  return (
57
  <Modal
58
  title={t('addLlmTitle', { name: llmFactory })}
@@ -85,18 +105,11 @@ const OllamaModal = ({
85
  rules={[{ required: true, message: t('modelTypeMessage') }]}
86
  >
87
  <Select placeholder={t('modelTypeMessage')}>
88
- {llmFactory === 'HuggingFace' ? (
89
- <Option value="embedding">embedding</Option>
90
- ) : (
91
- <>
92
- <Option value="chat">chat</Option>
93
- <Option value="embedding">embedding</Option>
94
- <Option value="rerank">rerank</Option>
95
- <Option value="image2text">image2text</Option>
96
- <Option value="audio2text">audio2text</Option>
97
- <Option value="text2andio">text2andio</Option>
98
- </>
99
- )}
100
  </Select>
101
  </Form.Item>
102
  <Form.Item<FieldType>
 
53
  const url =
54
  llmFactoryToUrlMap[llmFactory as LlmFactory] ||
55
  'https://github.com/infiniflow/ragflow/blob/main/docs/guides/deploy_local_llm.mdx';
56
+ const optionsMap = {
57
+ HuggingFace: [{ value: 'embedding', label: 'embedding' }],
58
+ Xinference: [
59
+ { value: 'chat', label: 'chat' },
60
+ { value: 'embedding', label: 'embedding' },
61
+ { value: 'rerank', label: 'rerank' },
62
+ { value: 'image2text', label: 'image2text' },
63
+ { value: 'speech2text', label: 'sequence2text' },
64
+ { value: 'tts', label: 'tts' },
65
+ ],
66
+ Default: [
67
+ { value: 'chat', label: 'chat' },
68
+ { value: 'embedding', label: 'embedding' },
69
+ { value: 'rerank', label: 'rerank' },
70
+ { value: 'image2text', label: 'image2text' },
71
+ ],
72
+ };
73
+ const getOptions = (factory: string) => {
74
+ return optionsMap[factory as keyof typeof optionsMap] || optionsMap.Default;
75
+ };
76
  return (
77
  <Modal
78
  title={t('addLlmTitle', { name: llmFactory })}
 
105
  rules={[{ required: true, message: t('modelTypeMessage') }]}
106
  >
107
  <Select placeholder={t('modelTypeMessage')}>
108
+ {getOptions(llmFactory).map((option) => (
109
+ <Option key={option.value} value={option.value}>
110
+ {option.label}
111
+ </Option>
112
+ ))}
 
 
 
 
 
 
 
113
  </Select>
114
  </Form.Item>
115
  <Form.Item<FieldType>