H committed on
Commit
7c30742
·
1 Parent(s): da0feec

Add sequence2txt model.py (#1633)

Browse files

### What problem does this PR solve?

#1514

### Type of change

- [x] New Feature (non-breaking change which adds functionality)

rag/llm/__init__.py CHANGED
@@ -17,7 +17,7 @@ from .embedding_model import *
17
  from .chat_model import *
18
  from .cv_model import *
19
  from .rerank_model import *
20
-
21
 
22
  EmbeddingModel = {
23
  "Ollama": OllamaEmbed,
@@ -81,3 +81,12 @@ RerankModel = {
81
  "Youdao": YoudaoRerank,
82
  "Xinference": XInferenceRerank
83
  }
 
 
 
 
 
 
 
 
 
 
17
  from .chat_model import *
18
  from .cv_model import *
19
  from .rerank_model import *
20
+ from .sequence2txt_model import *
21
 
22
  EmbeddingModel = {
23
  "Ollama": OllamaEmbed,
 
81
  "Youdao": YoudaoRerank,
82
  "Xinference": XInferenceRerank
83
  }
84
+
85
+
86
# Lookup table from a provider label to the speech-to-text wrapper class
# (imported from .sequence2txt_model) registered under that label.
Seq2txtModel = {
    "OpenAI": GPTSeq2txt,
    "Tongyi-Qianwen": QWenSeq2txt,
    "Ollama": OllamaSeq2txt,
    "Azure-OpenAI": AzureSeq2txt,
    "Xinference": XinferenceSeq2txt,
}
rag/llm/sequence2txt_model.py ADDED
@@ -0,0 +1,89 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #
2
+ # Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ #
16
+ from openai.lib.azure import AzureOpenAI
17
+ from zhipuai import ZhipuAI
18
+ import io
19
+ from abc import ABC
20
+ from ollama import Client
21
+ from openai import OpenAI
22
+ import os
23
+ import json
24
+ from rag.utils import num_tokens_from_string
25
+
26
+
27
class Base(ABC):
    """Common base for speech-to-text ("sequence2txt") model wrappers.

    The constructor here stores nothing. ``transcription`` reads
    ``self.client`` and ``self.model_name``, so a subclass that relies on the
    inherited ``transcription`` must set both in its own ``__init__``.
    """

    def __init__(self, key, model_name):
        # Intentionally empty: each subclass builds its own client.
        pass

    def transcription(self, audio, **kwargs):
        """Transcribe *audio* and return ``(text, token_count)``.

        Args:
            audio: Forwarded unchanged as the ``file`` argument of
                ``self.client.audio.transcriptions.create``.
            **kwargs: Accepted so callers can use one call shape for every
                provider; not used by this implementation.

        Returns:
            A tuple of the stripped transcript and the value of
            ``num_tokens_from_string`` for that transcript.
        """
        result = self.client.audio.transcriptions.create(
            model=self.model_name,
            file=audio,
            response_format="text",
        )
        # With response_format="text" the OpenAI SDK returns the transcript as
        # a plain str, which has no ``.text`` attribute; other response shapes
        # carry the transcript on ``.text``. Accept both.
        text = result if isinstance(result, str) else result.text
        text = text.strip()
        return text, num_tokens_from_string(text)
38
+
39
+
40
class GPTSeq2txt(Base):
    """Speech-to-text wrapper that talks to an OpenAI-style endpoint."""

    def __init__(self, key, model_name="whisper-1", base_url="https://api.openai.com/v1"):
        """Create the OpenAI client and remember the model name.

        A falsy *base_url* (``None``, ``""``) is replaced by the public
        OpenAI endpoint.
        """
        endpoint = base_url or "https://api.openai.com/v1"
        self.model_name = model_name
        self.client = OpenAI(api_key=key, base_url=endpoint)
45
+
46
+
47
class QWenSeq2txt(Base):
    """Speech-to-text wrapper built on DashScope's ``Recognition`` API."""

    def __init__(self, key, model_name="paraformer-realtime-8k-v1", **kwargs):
        """Remember the model name and install *key* as the DashScope API key.

        ``dashscope.api_key`` is a module attribute, so the key set here is
        visible to every other user of ``dashscope`` in the process.
        """
        # Local import keeps dashscope out of module import time.
        import dashscope
        dashscope.api_key = key
        self.model_name = model_name

    def transcription(self, audio, format, **kwargs):
        """Recognise speech in *audio* and return ``(text, token_count)``.

        Args:
            audio: Passed unchanged to ``Recognition.call``.
            format: Audio format name forwarded to ``Recognition``.
            **kwargs: Optional ``sample_rate`` (defaults to 16000). Any other
                keyword is ignored, so callers may pass the same extras they
                pass to the other providers' ``transcription``.

        Returns:
            On success, the recognised sentences joined one per line (each
            followed by a newline) and the ``num_tokens_from_string`` count of
            that text. On a non-OK status, ``("**ERROR**: <message>", 0)``.
        """
        from http import HTTPStatus
        from dashscope.audio.asr import Recognition

        # NOTE(review): the default model name says "8k" while the default
        # sample rate is 16000 - confirm which one the service expects.
        recognition = Recognition(model=self.model_name,
                                  format=format,
                                  sample_rate=kwargs.get("sample_rate", 16000),
                                  callback=None)
        result = recognition.call(audio)

        if result.status_code != HTTPStatus.OK:
            # f-string so a non-str (e.g. None) message cannot raise here.
            return f"**ERROR**: {result.message}", 0

        sentences = result.get_sentence()
        # Presumably get_sentence() yields either one sentence or a list of
        # them (TODO confirm against the installed dashscope version);
        # normalise to a list so a lone dict is not iterated key by key.
        if sentences is None:
            sentences = []
        elif isinstance(sentences, (dict, str, bytes)):
            sentences = [sentences]

        ans = "".join(self._sentence_text(item) + "\n" for item in sentences)
        return ans, num_tokens_from_string(ans)

    @staticmethod
    def _sentence_text(sentence):
        """Return the text of one recognition sentence as ``str``."""
        # Assumes dict-shaped sentences keep their transcript under "text"
        # - TODO confirm against the dashscope SDK documentation.
        if isinstance(sentence, dict):
            sentence = sentence.get("text") or ""
        if isinstance(sentence, bytes):
            return sentence.decode("utf-8", errors="replace")
        # Convert before concatenating; the original added '\n' to the raw
        # object first, which raises TypeError for anything that is not a str.
        return str(sentence)
70
+
71
+
72
class OllamaSeq2txt(Base):
    """Speech-to-text wrapper holding an Ollama ``Client``."""

    def __init__(self, key, model_name, lang="Chinese", **kwargs):
        """Build the Ollama client from ``kwargs["base_url"]``.

        *key* is not used. A missing ``base_url`` keyword raises ``KeyError``.
        """
        # NOTE(review): the inherited transcription() calls
        # self.client.audio.transcriptions.create - confirm that the Ollama
        # client actually provides that attribute.
        host = kwargs["base_url"]
        self.client = Client(host=host)
        self.lang = lang
        self.model_name = model_name
77
+
78
+
79
class AzureSeq2txt(Base):
    """Speech-to-text wrapper backed by an ``AzureOpenAI`` client."""

    def __init__(self, key, model_name, lang="Chinese", **kwargs):
        """Build the Azure client from *key* and ``kwargs["base_url"]``.

        The API version is pinned to ``"2024-02-01"``. A missing ``base_url``
        keyword raises ``KeyError``.
        """
        endpoint = kwargs["base_url"]
        self.client = AzureOpenAI(
            api_key=key,
            azure_endpoint=endpoint,
            api_version="2024-02-01",
        )
        self.lang = lang
        self.model_name = model_name
84
+
85
+
86
class XinferenceSeq2txt(Base):
    """Speech-to-text wrapper that reaches Xinference via the OpenAI client."""

    def __init__(self, key, model_name="", base_url=""):
        """Create the client for *base_url* and remember the model name.

        *key* is not forwarded; the client is always built with the
        placeholder API key ``"xxx"``.
        """
        placeholder_key = "xxx"
        self.model_name = model_name
        self.client = OpenAI(api_key=placeholder_key, base_url=base_url)