tomxxie commited on
Commit
e4f8633
·
1 Parent(s): 50836ff

实现ui界面

Browse files
.idea/.gitignore ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ # 默认忽略的文件
2
+ /shelf/
3
+ /workspace.xml
.idea/OSUM.iml ADDED
@@ -0,0 +1,14 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ <?xml version="1.0" encoding="UTF-8"?>
2
+ <module type="PYTHON_MODULE" version="4">
3
+ <component name="NewModuleRootManager">
4
+ <content url="file://$MODULE_DIR$">
5
+ <excludeFolder url="file://$MODULE_DIR$/venv" />
6
+ </content>
7
+ <orderEntry type="inheritedJdk" />
8
+ <orderEntry type="sourceFolder" forTests="false" />
9
+ </component>
10
+ <component name="PyDocumentationSettings">
11
+ <option name="format" value="PLAIN" />
12
+ <option name="myDocStringFormat" value="Plain" />
13
+ </component>
14
+ </module>
.idea/inspectionProfiles/Project_Default.xml ADDED
@@ -0,0 +1,220 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ <component name="InspectionProjectProfileManager">
2
+ <profile version="1.0">
3
+ <option name="myName" value="Project Default" />
4
+ <inspection_tool class="PyPackageRequirementsInspection" enabled="true" level="WARNING" enabled_by_default="true">
5
+ <option name="ignoredPackages">
6
+ <value>
7
+ <list size="207">
8
+ <item index="0" class="java.lang.String" itemvalue="flake8-executable" />
9
+ <item index="1" class="java.lang.String" itemvalue="cpplint" />
10
+ <item index="2" class="java.lang.String" itemvalue="pyflakes" />
11
+ <item index="3" class="java.lang.String" itemvalue="textgrid" />
12
+ <item index="4" class="java.lang.String" itemvalue="flake8-pyi" />
13
+ <item index="5" class="java.lang.String" itemvalue="clang-format" />
14
+ <item index="6" class="java.lang.String" itemvalue="pre-commit" />
15
+ <item index="7" class="java.lang.String" itemvalue="pycodestyle" />
16
+ <item index="8" class="java.lang.String" itemvalue="openai-whisper" />
17
+ <item index="9" class="java.lang.String" itemvalue="flake8-bugbear" />
18
+ <item index="10" class="java.lang.String" itemvalue="flake8" />
19
+ <item index="11" class="java.lang.String" itemvalue="deepspeed" />
20
+ <item index="12" class="java.lang.String" itemvalue="mccabe" />
21
+ <item index="13" class="java.lang.String" itemvalue="tensorboardX" />
22
+ <item index="14" class="java.lang.String" itemvalue="librosa" />
23
+ <item index="15" class="java.lang.String" itemvalue="langid" />
24
+ <item index="16" class="java.lang.String" itemvalue="flake8-comprehensions" />
25
+ <item index="17" class="java.lang.String" itemvalue="editdistance" />
26
+ <item index="18" class="java.lang.String" itemvalue="scipy" />
27
+ <item index="19" class="java.lang.String" itemvalue="tensorboard" />
28
+ <item index="20" class="java.lang.String" itemvalue="soundfile" />
29
+ <item index="21" class="java.lang.String" itemvalue="matplotlib" />
30
+ <item index="22" class="java.lang.String" itemvalue="torch" />
31
+ <item index="23" class="java.lang.String" itemvalue="Levenshtein" />
32
+ <item index="24" class="java.lang.String" itemvalue="visdom" />
33
+ <item index="25" class="java.lang.String" itemvalue="sox" />
34
+ <item index="26" class="java.lang.String" itemvalue="pypinyin" />
35
+ <item index="27" class="java.lang.String" itemvalue="onnxsim" />
36
+ <item index="28" class="java.lang.String" itemvalue="kaldi-decoder" />
37
+ <item index="29" class="java.lang.String" itemvalue="black" />
38
+ <item index="30" class="java.lang.String" itemvalue="dill" />
39
+ <item index="31" class="java.lang.String" itemvalue="onnxconverter_common" />
40
+ <item index="32" class="java.lang.String" itemvalue="typeguard" />
41
+ <item index="33" class="java.lang.String" itemvalue="kaldialign" />
42
+ <item index="34" class="java.lang.String" itemvalue="onnxruntime" />
43
+ <item index="35" class="java.lang.String" itemvalue="kaldifst" />
44
+ <item index="36" class="java.lang.String" itemvalue="num2words" />
45
+ <item index="37" class="java.lang.String" itemvalue="pycantonese" />
46
+ <item index="38" class="java.lang.String" itemvalue="isort" />
47
+ <item index="39" class="java.lang.String" itemvalue="onnxoptimizer" />
48
+ <item index="40" class="java.lang.String" itemvalue="kaldilm" />
49
+ <item index="41" class="java.lang.String" itemvalue="onnx" />
50
+ <item index="42" class="java.lang.String" itemvalue="modelscope" />
51
+ <item index="43" class="java.lang.String" itemvalue="numba" />
52
+ <item index="44" class="java.lang.String" itemvalue="bs4" />
53
+ <item index="45" class="java.lang.String" itemvalue="jieba" />
54
+ <item index="46" class="java.lang.String" itemvalue="scikit-learn" />
55
+ <item index="47" class="java.lang.String" itemvalue="dgl" />
56
+ <item index="48" class="java.lang.String" itemvalue="setuptools" />
57
+ <item index="49" class="java.lang.String" itemvalue="lhotse" />
58
+ <item index="50" class="java.lang.String" itemvalue="requests" />
59
+ <item index="51" class="java.lang.String" itemvalue="numpy" />
60
+ <item index="52" class="java.lang.String" itemvalue="datasets" />
61
+ <item index="53" class="java.lang.String" itemvalue="torchvision" />
62
+ <item index="54" class="java.lang.String" itemvalue="selenium" />
63
+ <item index="55" class="java.lang.String" itemvalue="tensorboardx" />
64
+ <item index="56" class="java.lang.String" itemvalue="lxml" />
65
+ <item index="57" class="java.lang.String" itemvalue="torchaudio" />
66
+ <item index="58" class="java.lang.String" itemvalue="transformers" />
67
+ <item index="59" class="java.lang.String" itemvalue="dataclasses" />
68
+ <item index="60" class="java.lang.String" itemvalue="loralib" />
69
+ <item index="61" class="java.lang.String" itemvalue="netron" />
70
+ <item index="62" class="java.lang.String" itemvalue="tqdm" />
71
+ <item index="63" class="java.lang.String" itemvalue="regex" />
72
+ <item index="64" class="java.lang.String" itemvalue="zhconv" />
73
+ <item index="65" class="java.lang.String" itemvalue="pillow" />
74
+ <item index="66" class="java.lang.String" itemvalue="espnet" />
75
+ <item index="67" class="java.lang.String" itemvalue="flask" />
76
+ <item index="68" class="java.lang.String" itemvalue="einops" />
77
+ <item index="69" class="java.lang.String" itemvalue="joblib" />
78
+ <item index="70" class="java.lang.String" itemvalue="fairseq" />
79
+ <item index="71" class="java.lang.String" itemvalue="huggingface-hub" />
80
+ <item index="72" class="java.lang.String" itemvalue="xformers" />
81
+ <item index="73" class="java.lang.String" itemvalue="nvidia-cuda-cupti-cu11" />
82
+ <item index="74" class="java.lang.String" itemvalue="fsspec" />
83
+ <item index="75" class="java.lang.String" itemvalue="nvidia-cusolver-cu11" />
84
+ <item index="76" class="java.lang.String" itemvalue="nvidia-curand-cu11" />
85
+ <item index="77" class="java.lang.String" itemvalue="filelock" />
86
+ <item index="78" class="java.lang.String" itemvalue="lit" />
87
+ <item index="79" class="java.lang.String" itemvalue="pip" />
88
+ <item index="80" class="java.lang.String" itemvalue="safetensors" />
89
+ <item index="81" class="java.lang.String" itemvalue="sentencepiece" />
90
+ <item index="82" class="java.lang.String" itemvalue="urllib" />
91
+ <item index="83" class="java.lang.String" itemvalue="certifi" />
92
+ <item index="84" class="java.lang.String" itemvalue="nvidia-cufft-cu11" />
93
+ <item index="85" class="java.lang.String" itemvalue="accelerate" />
94
+ <item index="86" class="java.lang.String" itemvalue="nvidia-cuda-runtime-cu11" />
95
+ <item index="87" class="java.lang.String" itemvalue="sacrebleu" />
96
+ <item index="88" class="java.lang.String" itemvalue="sympy" />
97
+ <item index="89" class="java.lang.String" itemvalue="tokenizers" />
98
+ <item index="90" class="java.lang.String" itemvalue="Jinja" />
99
+ <item index="91" class="java.lang.String" itemvalue="portalocker" />
100
+ <item index="92" class="java.lang.String" itemvalue="pydantic" />
101
+ <item index="93" class="java.lang.String" itemvalue="triton" />
102
+ <item index="94" class="java.lang.String" itemvalue="nvidia-cuda-nvrtc-cu11" />
103
+ <item index="95" class="java.lang.String" itemvalue="bitsandbytes" />
104
+ <item index="96" class="java.lang.String" itemvalue="omegaconf" />
105
+ <item index="97" class="java.lang.String" itemvalue="psutil" />
106
+ <item index="98" class="java.lang.String" itemvalue="platformdirs" />
107
+ <item index="99" class="java.lang.String" itemvalue="ninja" />
108
+ <item index="100" class="java.lang.String" itemvalue="nvidia-nvtx-cu11" />
109
+ <item index="101" class="java.lang.String" itemvalue="peft" />
110
+ <item index="102" class="java.lang.String" itemvalue="charset-normalizer" />
111
+ <item index="103" class="java.lang.String" itemvalue="hjson" />
112
+ <item index="104" class="java.lang.String" itemvalue="msgpack" />
113
+ <item index="105" class="java.lang.String" itemvalue="idna" />
114
+ <item index="106" class="java.lang.String" itemvalue="networkx" />
115
+ <item index="107" class="java.lang.String" itemvalue="more-itertools" />
116
+ <item index="108" class="java.lang.String" itemvalue="nvidia-ml-py" />
117
+ <item index="109" class="java.lang.String" itemvalue="antlr4-python3-runtime" />
118
+ <item index="110" class="java.lang.String" itemvalue="wcwidth" />
119
+ <item index="111" class="java.lang.String" itemvalue="py-cpuinfo" />
120
+ <item index="112" class="java.lang.String" itemvalue="llvmlite" />
121
+ <item index="113" class="java.lang.String" itemvalue="nvidia-cublas-cu11" />
122
+ <item index="114" class="java.lang.String" itemvalue="nvidia-cusparse-cu11" />
123
+ <item index="115" class="java.lang.String" itemvalue="nvidia-nccl-cu11" />
124
+ <item index="116" class="java.lang.String" itemvalue="nvidia-cudnn-cu11" />
125
+ <item index="117" class="java.lang.String" itemvalue="gxl-ai-utils" />
126
+ <item index="118" class="java.lang.String" itemvalue="threadpoolct" />
127
+ <item index="119" class="java.lang.String" itemvalue="Cython" />
128
+ <item index="120" class="java.lang.String" itemvalue="lazy_loader" />
129
+ <item index="121" class="java.lang.String" itemvalue="soxr" />
130
+ <item index="122" class="java.lang.String" itemvalue="wheel" />
131
+ <item index="123" class="java.lang.String" itemvalue="gpustat" />
132
+ <item index="124" class="java.lang.String" itemvalue="s3prl" />
133
+ <item index="125" class="java.lang.String" itemvalue="blessed" />
134
+ <item index="126" class="java.lang.String" itemvalue="ffmpeg-python" />
135
+ <item index="127" class="java.lang.String" itemvalue="pooch" />
136
+ <item index="128" class="java.lang.String" itemvalue="hydra-core" />
137
+ <item index="129" class="java.lang.String" itemvalue="future" />
138
+ <item index="130" class="java.lang.String" itemvalue="cmake" />
139
+ <item index="131" class="java.lang.String" itemvalue="typing_extensions" />
140
+ <item index="132" class="java.lang.String" itemvalue="imblearn" />
141
+ <item index="133" class="java.lang.String" itemvalue="imbalanced-learn" />
142
+ <item index="134" class="java.lang.String" itemvalue="onnxconverter-common" />
143
+ <item index="135" class="java.lang.String" itemvalue="panda" />
144
+ <item index="136" class="java.lang.String" itemvalue="emoji" />
145
+ <item index="137" class="java.lang.String" itemvalue="nltk" />
146
+ <item index="138" class="java.lang.String" itemvalue="nvidia-cuda-cupti-cu12" />
147
+ <item index="139" class="java.lang.String" itemvalue="PyYAML" />
148
+ <item index="140" class="java.lang.String" itemvalue="nvidia-cufft-cu12" />
149
+ <item index="141" class="java.lang.String" itemvalue="cycler" />
150
+ <item index="142" class="java.lang.String" itemvalue="httptools" />
151
+ <item index="143" class="java.lang.String" itemvalue="frozenlist" />
152
+ <item index="144" class="java.lang.String" itemvalue="nvidia-cusolver-cu12" />
153
+ <item index="145" class="java.lang.String" itemvalue="nvidia-curand-cu12" />
154
+ <item index="146" class="java.lang.String" itemvalue="Pygments" />
155
+ <item index="147" class="java.lang.String" itemvalue="aliyun-python-sdk-core" />
156
+ <item index="148" class="java.lang.String" itemvalue="anyio" />
157
+ <item index="149" class="java.lang.String" itemvalue="multiprocess" />
158
+ <item index="150" class="java.lang.String" itemvalue="nvidia-cuda-runtime-cu12" />
159
+ <item index="151" class="java.lang.String" itemvalue="pyparsing" />
160
+ <item index="152" class="java.lang.String" itemvalue="vocos" />
161
+ <item index="153" class="java.lang.String" itemvalue="Markdown" />
162
+ <item index="154" class="java.lang.String" itemvalue="uvloop" />
163
+ <item index="155" class="java.lang.String" itemvalue="xxhash" />
164
+ <item index="156" class="java.lang.String" itemvalue="nvidia-cuda-nvrtc-cu12" />
165
+ <item index="157" class="java.lang.String" itemvalue="kaldifeat" />
166
+ <item index="158" class="java.lang.String" itemvalue="Werkzeug" />
167
+ <item index="159" class="java.lang.String" itemvalue="k2" />
168
+ <item index="160" class="java.lang.String" itemvalue="aiohappyeyeballs" />
169
+ <item index="161" class="java.lang.String" itemvalue="cryptography" />
170
+ <item index="162" class="java.lang.String" itemvalue="kiwisolver" />
171
+ <item index="163" class="java.lang.String" itemvalue="attrs" />
172
+ <item index="164" class="java.lang.String" itemvalue="coloredlogs" />
173
+ <item index="165" class="java.lang.String" itemvalue="contourpy" />
174
+ <item index="166" class="java.lang.String" itemvalue="simplejson" />
175
+ <item index="167" class="java.lang.String" itemvalue="fonttools" />
176
+ <item index="168" class="java.lang.String" itemvalue="flatbuffers" />
177
+ <item index="169" class="java.lang.String" itemvalue="nvidia-cublas-cu12" />
178
+ <item index="170" class="java.lang.String" itemvalue="nvidia-nvtx-cu12" />
179
+ <item index="171" class="java.lang.String" itemvalue="cytoolz" />
180
+ <item index="172" class="java.lang.String" itemvalue="propcache" />
181
+ <item index="173" class="java.lang.String" itemvalue="oss2" />
182
+ <item index="174" class="java.lang.String" itemvalue="vector-quantize-pytorch" />
183
+ <item index="175" class="java.lang.String" itemvalue="encodec" />
184
+ <item index="176" class="java.lang.String" itemvalue="einx" />
185
+ <item index="177" class="java.lang.String" itemvalue="nvidia-nvjitlink-cu12" />
186
+ <item index="178" class="java.lang.String" itemvalue="cffi" />
187
+ <item index="179" class="java.lang.String" itemvalue="nvidia-cusparse-cu12" />
188
+ <item index="180" class="java.lang.String" itemvalue="crcmod" />
189
+ <item index="181" class="java.lang.String" itemvalue="pybase16384" />
190
+ <item index="182" class="java.lang.String" itemvalue="bitarray" />
191
+ <item index="183" class="java.lang.String" itemvalue="nvidia-nccl-cu12" />
192
+ <item index="184" class="java.lang.String" itemvalue="sniffio" />
193
+ <item index="185" class="java.lang.String" itemvalue="urllib3" />
194
+ <item index="186" class="java.lang.String" itemvalue="flashlight" />
195
+ <item index="187" class="java.lang.String" itemvalue="pyarrow" />
196
+ <item index="188" class="java.lang.String" itemvalue="rich" />
197
+ <item index="189" class="java.lang.String" itemvalue="nvidia-cudnn-cu12" />
198
+ <item index="190" class="java.lang.String" itemvalue="addict" />
199
+ <item index="191" class="java.lang.String" itemvalue="gxl_ai_utils" />
200
+ <item index="192" class="java.lang.String" itemvalue="jmespath" />
201
+ <item index="193" class="java.lang.String" itemvalue="toolz" />
202
+ <item index="194" class="java.lang.String" itemvalue="aliyun-python-sdk-kms" />
203
+ <item index="195" class="java.lang.String" itemvalue="sanic" />
204
+ <item index="196" class="java.lang.String" itemvalue="aiohttp" />
205
+ <item index="197" class="java.lang.String" itemvalue="multidict" />
206
+ <item index="198" class="java.lang.String" itemvalue="grpcio" />
207
+ <item index="199" class="java.lang.String" itemvalue="yarl" />
208
+ <item index="200" class="java.lang.String" itemvalue="frozendict" />
209
+ <item index="201" class="java.lang.String" itemvalue="pycryptodome" />
210
+ <item index="202" class="java.lang.String" itemvalue="pytz" />
211
+ <item index="203" class="java.lang.String" itemvalue="aiosignal" />
212
+ <item index="204" class="java.lang.String" itemvalue="ujson" />
213
+ <item index="205" class="java.lang.String" itemvalue="cloudpickle" />
214
+ <item index="206" class="java.lang.String" itemvalue="ml-dtypes" />
215
+ </list>
216
+ </value>
217
+ </option>
218
+ </inspection_tool>
219
+ </profile>
220
+ </component>
.idea/inspectionProfiles/profiles_settings.xml ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ <component name="InspectionProjectProfileManager">
2
+ <settings>
3
+ <option name="USE_PROJECT_PROFILE" value="false" />
4
+ <version value="1.0" />
5
+ </settings>
6
+ </component>
.idea/misc.xml ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ <?xml version="1.0" encoding="UTF-8"?>
2
+ <project version="4">
3
+ <component name="Black">
4
+ <option name="sdkName" value="Python 3.12 (OSUM)" />
5
+ </component>
6
+ <component name="ProjectRootManager" version="2" project-jdk-name="Python 3.12 (OSUM)" project-jdk-type="Python SDK" />
7
+ </project>
.idea/modules.xml ADDED
@@ -0,0 +1,8 @@
 
 
 
 
 
 
 
 
 
1
+ <?xml version="1.0" encoding="UTF-8"?>
2
+ <project version="4">
3
+ <component name="ProjectModuleManager">
4
+ <modules>
5
+ <module fileurl="file://$PROJECT_DIR$/.idea/OSUM.iml" filepath="$PROJECT_DIR$/.idea/OSUM.iml" />
6
+ </modules>
7
+ </component>
8
+ </project>
.idea/vcs.xml ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ <?xml version="1.0" encoding="UTF-8"?>
2
+ <project version="4">
3
+ <component name="VcsDirectoryMappings">
4
+ <mapping directory="" vcs="Git" />
5
+ </component>
6
+ </project>
app.py CHANGED
@@ -1,64 +1,257 @@
 
 
 
 
1
  import gradio as gr
2
- from huggingface_hub import InferenceClient
3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4
  """
5
- For more information on `huggingface_hub` Inference API support, please check the docs: https://huggingface.co/docs/huggingface_hub/v0.22.2/en/guides/inference
6
- """
7
- client = InferenceClient("HuggingFaceH4/zephyr-7b-beta")
8
 
 
 
 
 
 
 
 
 
 
 
 
9
 
10
- def respond(
11
- message,
12
- history: list[tuple[str, str]],
13
- system_message,
14
- max_tokens,
15
- temperature,
16
- top_p,
17
- ):
18
- messages = [{"role": "system", "content": system_message}]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
19
 
20
- for val in history:
21
- if val[0]:
22
- messages.append({"role": "user", "content": val[0]})
23
- if val[1]:
24
- messages.append({"role": "assistant", "content": val[1]})
 
25
 
26
- messages.append({"role": "user", "content": message})
 
 
 
 
 
 
 
 
27
 
28
- response = ""
 
 
29
 
30
- for message in client.chat_completion(
31
- messages,
32
- max_tokens=max_tokens,
33
- stream=True,
34
- temperature=temperature,
35
- top_p=top_p,
36
- ):
37
- token = message.choices[0].delta.content
38
 
39
- response += token
40
- yield response
 
 
 
 
 
 
 
 
 
 
41
 
 
 
 
 
 
 
42
 
43
- """
44
- For information on how to customize the ChatInterface, peruse the gradio docs: https://www.gradio.app/docs/chatinterface
45
- """
46
- demo = gr.ChatInterface(
47
- respond,
48
- additional_inputs=[
49
- gr.Textbox(value="You are a friendly Chatbot.", label="System message"),
50
- gr.Slider(minimum=1, maximum=2048, value=512, step=1, label="Max new tokens"),
51
- gr.Slider(minimum=0.1, maximum=4.0, value=0.7, step=0.1, label="Temperature"),
52
- gr.Slider(
53
- minimum=0.1,
54
- maximum=1.0,
55
- value=0.95,
56
- step=0.05,
57
- label="Top-p (nucleus sampling)",
58
- ),
59
- ],
60
- )
61
-
62
-
63
- if __name__ == "__main__":
64
- demo.launch()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import base64
2
+ import json
3
+ import time
4
+
5
  import gradio as gr
6
+ import os
7
 
8
+ import sys
9
+
10
+
11
+ sys.path.insert(0, '../../../../')
12
+ # from gxl_ai_utils.utils import utils_file
13
+ # from wenet.utils.init_tokenizer import init_tokenizer
14
+ # from gxl_ai_utils.config.gxl_config import GxlNode
15
+ # from wenet.utils.init_model import init_model
16
+ import logging
17
+ # import librosa
18
+ # import torch
19
+ # import torchaudio
20
+ # import numpy as np
21
+
22
+ # 将图片转换为 Base64
23
+ with open("./实验室.png", "rb") as image_file:
24
+ encoded_string = base64.b64encode(image_file.read()).decode("utf-8")
25
+
26
+ # with open("./cat.jpg", "rb") as image_file:
27
+ # encoded_string = base64.b64encode(image_file.read()).decode("utf-8")
28
+
29
+ # 自定义CSS样式
30
+ custom_css = """
31
+ /* 自定义CSS样式 */
32
  """
 
 
 
33
 
34
+ # 任务提示映射
35
+ TASK_PROMPT_MAPPING = {
36
+ "ASR (Automatic Speech Recognition)": "执行语音识别任务,将音频转换为文字。",
37
+ "SRWT (Speech Recognition with Timestamps)": "请转录音频内容,并为每个英文词汇及其对应的中文翻译标注出精确到0.1秒的起止时间,时间范围用<>括起来。",
38
+ "VED (Vocal Event Detection)(类别:laugh,cough,cry,screaming,sigh,throat clearing,sneeze,other)": "请将音频转录为文字记录,并在记录末尾标注<音频事件>标签,音频事件共8种:laugh,cough,cry,screaming,sigh,throat clearing,sneeze,other。",
39
+ "SER (Speech Emotion Recognition)(类别:sad,anger,neutral,happy,surprise,fear,disgust,和other)": "请将音频内容转录成文字记录,并在记录末尾标注<情感>标签,情感共8种:sad,anger,neutral,happy,surprise,fear,disgust,和other。",
40
+ "SSR (Speaking Style Recognition)(类别:新闻科普,恐怖故事,童话故事,客服,诗歌散文,有声书,日常口语,其他)": "请将音频内容进行文字转录,并在最后添加<风格>标签,标签共8种:新闻科普、恐怖故事、童话故事、客服、诗歌散文、有声书、日常口语、其他。",
41
+ "SGC (Speaker Gender Classification)(类别:female,male)": "请将音频转录为文本,并在文本结尾处标注<性别>标签,性别为female或male。",
42
+ "SAP (Speaker Age Prediction)(类别:child、adult和old)": "请将音频转录为文本,并在文本结尾处标注<年龄>标签,年龄划分为child、adult和old三种。",
43
+ "STTC (Speech to Text Chat)": "首先将语音转录为文字,然后对语音内容进行回复,转录和文字之间使用<开始回答>分割。"
44
+ }
45
 
46
+ gpu_id = 4
47
+ # def init_model_my():
48
+ # logging.basicConfig(level=logging.DEBUG,
49
+ # format='%(asctime)s %(levelname)s %(message)s')
50
+ # config_path = "/home/node54_tmpdata/xlgeng/code/wenet_undersdand_and_speech_xlgeng/examples/wenetspeech/whisper/exp/update_data/epoch_1_with_token/epoch_11.yaml"
51
+ # #config_path = "/home/work_nfs15/asr_data/ckpt/understanding_model/step_24999.yaml"
52
+ #
53
+ # checkpoint_path = "/home/node54_tmpdata/xlgeng/code/wenet_undersdand_and_speech_xlgeng/examples/wenetspeech/whisper/exp/update_data/epoch_1_with_token/epoch_11.pt"
54
+ # checkpoint_path = "/home/work_nfs15/asr_data/ckpt/understanding_model/epoch4/step_21249.pt"
55
+ # checkpoint_path = "/home/work_nfs15/asr_data/ckpt/understanding_model/epoch_13_with_asr-chat_full_data/step_32499/step_32499.pt"
56
+ # args = GxlNode({
57
+ # "checkpoint": checkpoint_path,
58
+ # })
59
+ # configs = utils_file.load_dict_from_yaml(config_path)
60
+ # model, configs = init_model(args, configs)
61
+ # model = model.cuda(gpu_id)
62
+ # tokenizer = init_tokenizer(configs)
63
+ # print(model)
64
+ # return model, tokenizer
65
+ #
66
+ # model, tokenizer = init_model_my()
67
+ #
68
+ # def do_resample(input_wav_path, output_wav_path):
69
+ # """"""
70
+ # print(f'input_wav_path: {input_wav_path}, output_wav_path: {output_wav_path}')
71
+ # waveform, sample_rate = torchaudio.load(input_wav_path)
72
+ # # 检查音频的维度
73
+ # num_channels = waveform.shape[0]
74
+ # # 如果音频是多通道的,则进行通道平均
75
+ # if num_channels > 1:
76
+ # waveform = torch.mean(waveform, dim=0, keepdim=True)
77
+ # waveform = torchaudio.transforms.Resample(
78
+ # orig_freq=sample_rate, new_freq=16000)(waveform)
79
+ # utils_file.makedir_for_file(output_wav_path)
80
+ # torchaudio.save(output_wav_path, waveform, 16000)
81
+ #
82
+ # def true_decode_fuc(input_wav_path, input_prompt):
83
+ # # input_prompt = TASK_PROMPT_MAPPING.get(input_prompt, "未知任务类型")
84
+ # print(f"wav_path: {input_wav_path}, prompt:{input_prompt}")
85
+ # timestamp_ms = int(time.time() * 1000)
86
+ # now_file_tmp_path_resample = f'/home/xlgeng/.cache/.temp/{timestamp_ms}_resample.wav'
87
+ # do_resample(input_wav_path, now_file_tmp_path_resample)
88
+ # # tmp_vad_path = f'/home/xlgeng/.cache/.temp/{timestamp_ms}_vad.wav'
89
+ # # remove_silence_torchaudio_ends(now_file_tmp_path_resample, tmp_vad_path)
90
+ # # input_wav_path = tmp_vad_path
91
+ # input_wav_path = now_file_tmp_path_resample
92
+ # waveform, sample_rate = torchaudio.load(input_wav_path)
93
+ # waveform = waveform.squeeze(0) # (channel=1, sample) -> (sample,)
94
+ # print(f'wavform shape: {waveform.shape}, sample_rate: {sample_rate}')
95
+ # window = torch.hann_window(400)
96
+ # stft = torch.stft(waveform,
97
+ # 400,
98
+ # 160,
99
+ # window=window,
100
+ # return_complex=True)
101
+ # magnitudes = stft[..., :-1].abs() ** 2
102
+ #
103
+ # filters = torch.from_numpy(
104
+ # librosa.filters.mel(sr=sample_rate,
105
+ # n_fft=400,
106
+ # n_mels=80))
107
+ # mel_spec = filters @ magnitudes
108
+ #
109
+ # # NOTE(xcsong): https://github.com/openai/whisper/discussions/269
110
+ # log_spec = torch.clamp(mel_spec, min=1e-10).log10()
111
+ # log_spec = torch.maximum(log_spec, log_spec.max() - 8.0)
112
+ # log_spec = (log_spec + 4.0) / 4.0
113
+ # feat = log_spec.transpose(0, 1)
114
+ # feat_lens = torch.tensor([feat.shape[0]], dtype=torch.int64).to(gpu_id)
115
+ # feat = feat.unsqueeze(0).to(gpu_id)
116
+ # # feat = feat.half()
117
+ # # feat_lens = feat_lens.half()
118
+ # res_text = model.generate(wavs=feat, wavs_len=feat_lens, prompt=input_prompt)[0]
119
+ # print("耿雪龙哈哈:", res_text)
120
+ # return res_text, now_file_tmp_path_resample
121
 
122
+ def do_decode(input_wav_path, input_prompt):
123
+ print(f'input_wav_path= {input_wav_path}, input_prompt= {input_prompt}')
124
+ # 省略处理逻辑
125
+ # output_res, now_file_tmp_path_resample= true_decode_fuc(input_wav_path, input_prompt)
126
+ output_res = f"耿雪龙哈哈:测试结果, input_wav_path= {input_wav_path}, input_prompt= {input_prompt}"
127
+ return output_res
128
 
129
+ def save_to_jsonl(if_correct, wav, prompt, res):
130
+ data = {
131
+ "if_correct": if_correct,
132
+ "wav": wav,
133
+ "task": prompt,
134
+ "res": res
135
+ }
136
+ with open("results.jsonl", "a", encoding="utf-8") as f:
137
+ f.write(json.dumps(data, ensure_ascii=False) + "\n")
138
 
139
+ def handle_submit(input_wav_path, input_prompt):
140
+ output_res = do_decode(input_wav_path, input_prompt)
141
+ return output_res
142
 
143
+ def download_audio(input_wav_path):
144
+ if input_wav_path:
145
+ # 返回文件路径供下载
146
+ return input_wav_path
147
+ else:
148
+ return None
 
 
149
 
150
+ # 创建Gradio界面
151
+ with gr.Blocks(css=custom_css) as demo:
152
+ # 添加标题
153
+ gr.Markdown(
154
+ f"""
155
+ <div style="display: flex; align-items: center; justify-content: center; text-align: center;">
156
+ <h1 style="font-family: 'Arial', sans-serif; color: #014377; font-size: 32px; margin-bottom: 0; display: inline-block; vertical-align: middle;">
157
+ OSUM Speech Understanding Model Test
158
+ </h1>
159
+ </div>
160
+ """
161
+ )
162
 
163
+ # 添加音频输入和任务选择
164
+ with gr.Row():
165
+ with gr.Column(scale=1):
166
+ audio_input = gr.Audio(label="录音", type="filepath")
167
+ with gr.Column(scale=1, min_width=300): # 给输出框设置最小宽度,确保等高对齐
168
+ output_text = gr.Textbox(label="输出结果", lines=8, placeholder="生成的结果将显示在这里...", interactive=False)
169
 
170
+ # 添加任务选择和自定义输入框
171
+ with gr.Row():
172
+ task_dropdown = gr.Dropdown(
173
+ label="任务",
174
+ choices=list(TASK_PROMPT_MAPPING.keys()) + ["自主输入文本"], # 新增选项
175
+ value="ASR (Automatic Speech Recognition)"
176
+ )
177
+ custom_prompt_input = gr.Textbox(label="自定义任务提示", placeholder="请输入自定义任务提示...", visible=False) # 新增文本输入框
178
+
179
+ # 添加按钮(下载按钮在左边,开始处理按钮在右边)
180
+ with gr.Row():
181
+ download_button = gr.DownloadButton("下载音频", variant="secondary", elem_classes=["button-height", "download-button"])
182
+ submit_button = gr.Button("开始处理", variant="primary", elem_classes=["button-height", "submit-button"])
183
+
184
+ # 添加确认组件
185
+ with gr.Row(visible=False) as confirmation_row:
186
+ gr.Markdown("请判断结果是否正确:")
187
+ confirmation_buttons = gr.Radio(
188
+ choices=["正确", "错误"],
189
+ label="",
190
+ interactive=True,
191
+ container=False,
192
+ elem_classes="confirmation-buttons"
193
+ )
194
+ save_button = gr.Button("提交反馈", variant="secondary")
195
+
196
+ # 添加底部内容
197
+ with gr.Row():
198
+ # 底部内容容器
199
+ with gr.Column(scale=1, min_width=800): # 设置最小宽度以确保内容居中
200
+ gr.Markdown(
201
+ f"""
202
+ <div style="position: fixed; bottom: 20px; left: 50%; transform: translateX(-50%); display: flex; align-items: center; justify-content: center; gap: 20px;">
203
+ <div style="text-align: center;">
204
+ <p style="margin: 0;"><strong>Audio, Speech and Language Processing Group (ASLP@NPU),</strong></p>
205
+ <p style="margin: 0;"><strong>Northwestern Polytechnical University</strong></p>
206
+ </div>
207
+ <img src="data:image/png;base64,{encoded_string}" alt="OSUM Logo" style="height: 80px; width: auto;">
208
+ </div>
209
+ """
210
+ )
211
+
212
+ # 绑定事件
213
+ def show_confirmation(output_res, input_wav_path, input_prompt):
214
+ return gr.update(visible=True), output_res, input_wav_path, input_prompt
215
+
216
+ def save_result(if_correct, wav, prompt, res):
217
+ save_to_jsonl(if_correct, wav, prompt, res)
218
+ return gr.update(visible=False)
219
+
220
+ def handle_submit(input_wav_path, task_choice, custom_prompt):
221
+ if task_choice == "自主输入文本":
222
+ input_prompt = custom_prompt # 使用用户输入的自定义文本
223
+ else:
224
+ input_prompt = TASK_PROMPT_MAPPING.get(task_choice, "未知任务类型") # 使用预定义的提示
225
+ output_res = do_decode(input_wav_path, input_prompt)
226
+ return output_res
227
+
228
+ task_dropdown.change(
229
+ fn=lambda choice: gr.update(visible=choice == "自主输入文本"),
230
+ inputs=task_dropdown,
231
+ outputs=custom_prompt_input
232
+ )
233
+
234
+ submit_button.click(
235
+ fn=handle_submit,
236
+ inputs=[audio_input, task_dropdown, custom_prompt_input],
237
+ outputs=output_text
238
+ ).then(
239
+ fn=show_confirmation,
240
+ inputs=[output_text, audio_input, task_dropdown],
241
+ outputs=[confirmation_row, output_text, audio_input, task_dropdown]
242
+ )
243
+
244
+ download_button.click(
245
+ fn=download_audio,
246
+ inputs=[audio_input],
247
+ outputs=[download_button] # 输出到 download_button
248
+ )
249
+
250
+ save_button.click(
251
+ fn=save_result,
252
+ inputs=[confirmation_buttons, audio_input, task_dropdown, output_text],
253
+ outputs=confirmation_row
254
+ )
255
+
256
+ if __name__== "__main__":
257
+ demo.launch()
app_old.py ADDED
@@ -0,0 +1,64 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ from huggingface_hub import InferenceClient
3
+
4
+ """
5
+ For more information on `huggingface_hub` Inference API support, please check the docs: https://huggingface.co/docs/huggingface_hub/v0.22.2/en/guides/inference
6
+ """
7
+ client = InferenceClient("HuggingFaceH4/zephyr-7b-beta")
8
+
9
+
10
+ def respond(
11
+ message,
12
+ history: list[tuple[str, str]],
13
+ system_message,
14
+ max_tokens,
15
+ temperature,
16
+ top_p,
17
+ ):
18
+ messages = [{"role": "system", "content": system_message}]
19
+
20
+ for val in history:
21
+ if val[0]:
22
+ messages.append({"role": "user", "content": val[0]})
23
+ if val[1]:
24
+ messages.append({"role": "assistant", "content": val[1]})
25
+
26
+ messages.append({"role": "user", "content": message})
27
+
28
+ response = ""
29
+
30
+ for message in client.chat_completion(
31
+ messages,
32
+ max_tokens=max_tokens,
33
+ stream=True,
34
+ temperature=temperature,
35
+ top_p=top_p,
36
+ ):
37
+ token = message.choices[0].delta.content
38
+
39
+ response += token
40
+ yield response
41
+
42
+
43
+ """
44
+ For information on how to customize the ChatInterface, peruse the gradio docs: https://www.gradio.app/docs/chatinterface
45
+ """
46
+ demo = gr.ChatInterface(
47
+ respond,
48
+ additional_inputs=[
49
+ gr.Textbox(value="You are a friendly Chatbot.", label="System message"),
50
+ gr.Slider(minimum=1, maximum=2048, value=512, step=1, label="Max new tokens"),
51
+ gr.Slider(minimum=0.1, maximum=4.0, value=0.7, step=0.1, label="Temperature"),
52
+ gr.Slider(
53
+ minimum=0.1,
54
+ maximum=1.0,
55
+ value=0.95,
56
+ step=0.05,
57
+ label="Top-p (nucleus sampling)",
58
+ ),
59
+ ],
60
+ )
61
+
62
+
63
+ if __name__ == "__main__":
64
+ demo.launch()