tomxxie committed
Commit 66817ed
Parent(s): d845e75

Adapt to zeroGPU

Files changed (4):
  1. .idea/OSUM.iml +1 -1
  2. .idea/misc.xml +1 -1
  3. app.py +78 -84
  4. 实验室.png → lab.png +0 -0
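
What "adapt to zeroGPU" means in practice: on a Hugging Face ZeroGPU Space no GPU is attached at startup; a device is leased only while a function decorated with @spaces.GPU runs. That is why the diff below drops the hard-coded gpu_id = 4 and replaces .cuda(gpu_id) / .to(gpu_id) with plain .cuda(). A minimal sketch of the pattern, with a tiny nn.Linear standing in for the real OSUM model:

import spaces
import torch

model = torch.nn.Linear(80, 80)  # stand-in model, built on CPU at startup

@spaces.GPU  # ZeroGPU attaches a GPU only while this call runs
def do_decode(feat: torch.Tensor) -> torch.Tensor:
    model.cuda()  # no explicit device index: ZeroGPU manages the device
    return model(feat.cuda()).cpu()
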
.idea/OSUM.iml CHANGED
@@ -4,7 +4,7 @@
   <content url="file://$MODULE_DIR$">
     <excludeFolder url="file://$MODULE_DIR$/venv" />
   </content>
-  <orderEntry type="inheritedJdk" />
+  <orderEntry type="jdk" jdkName="k2_gxl" jdkType="Python SDK" />
   <orderEntry type="sourceFolder" forTests="false" />
 </component>
 <component name="PyDocumentationSettings">
.idea/misc.xml CHANGED
@@ -3,5 +3,5 @@
   <component name="Black">
     <option name="sdkName" value="Python 3.12 (OSUM)" />
   </component>
-  <component name="ProjectRootManager" version="2" project-jdk-name="Python 3.12 (OSUM)" project-jdk-type="Python SDK" />
+  <component name="ProjectRootManager" version="2" project-jdk-name="k2_gxl" project-jdk-type="Python SDK" />
 </project>
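
The substantive change is in app.py below: the model-loading and feature-extraction code that used to be commented out is restored, and the GxlNode config wrapper from gxl_ai_utils is swapped for types.SimpleNamespace, presumably because init_model only needs attribute access to args.checkpoint. A two-line sketch of the swap (the path is a placeholder):

from types import SimpleNamespace

args = SimpleNamespace(**{"checkpoint": "step_32499.pt"})  # placeholder path
print(args.checkpoint)  # attribute access, as GxlNode provided
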
app.py CHANGED
@@ -1,6 +1,8 @@
 import base64
 import json
 import time
+from types import SimpleNamespace
+
 import spaces
 
 import gradio as gr
@@ -9,19 +11,18 @@ import os
 import sys
 
 
-# sys.path.insert(0, '../../../../')
-# from gxl_ai_utils.utils import utils_file
-# from wenet.utils.init_tokenizer import init_tokenizer
-# from gxl_ai_utils.config.gxl_config import GxlNode
-# from wenet.utils.init_model import init_model
+sys.path.insert(0, './')
+from gxl_ai_utils.utils import utils_file
+from wenet.utils.init_tokenizer import init_tokenizer
+from wenet.utils.init_model import init_model
 import logging
-# import librosa
-# import torch
-# import torchaudio
-# import numpy as np
+import librosa
+import torch
+import torchaudio
+import numpy as np
 
 # Convert the image to Base64
-with open("./实验室.png", "rb") as image_file:
+with open("lab.png", "rb") as image_file:
     encoded_string = base64.b64encode(image_file.read()).decode("utf-8")
 
 # with open("./cat.jpg", "rb") as image_file:
@@ -44,81 +45,74 @@ TASK_PROMPT_MAPPING = {
     "STTC (Speech to Text Chat)": "首先将语音转录为文字,然后对语音内容进行回复,转录和文字之间使用<开始回答>分割。"
 }
 
-gpu_id = 4
-# def init_model_my():
-#     logging.basicConfig(level=logging.DEBUG,
-#                         format='%(asctime)s %(levelname)s %(message)s')
-#     config_path = "/home/node54_tmpdata/xlgeng/code/wenet_undersdand_and_speech_xlgeng/examples/wenetspeech/whisper/exp/update_data/epoch_1_with_token/epoch_11.yaml"
-#     #config_path = "/home/work_nfs15/asr_data/ckpt/understanding_model/step_24999.yaml"
-#
-#     checkpoint_path = "/home/node54_tmpdata/xlgeng/code/wenet_undersdand_and_speech_xlgeng/examples/wenetspeech/whisper/exp/update_data/epoch_1_with_token/epoch_11.pt"
-#     checkpoint_path = "/home/work_nfs15/asr_data/ckpt/understanding_model/epoch4/step_21249.pt"
-#     checkpoint_path = "/home/work_nfs15/asr_data/ckpt/understanding_model/epoch_13_with_asr-chat_full_data/step_32499/step_32499.pt"
-#     args = GxlNode({
-#         "checkpoint": checkpoint_path,
-#     })
-#     configs = utils_file.load_dict_from_yaml(config_path)
-#     model, configs = init_model(args, configs)
-#     model = model.cuda(gpu_id)
-#     tokenizer = init_tokenizer(configs)
-#     print(model)
-#     return model, tokenizer
-#
+def init_model_my():
+    logging.basicConfig(level=logging.DEBUG,
+                        format='%(asctime)s %(levelname)s %(message)s')
+    config_path = "/home/node54_tmpdata/xlgeng/code/wenet_undersdand_and_speech_xlgeng/examples/wenetspeech/whisper/exp/update_data/epoch_1_with_token/epoch_11.yaml"
+    checkpoint_path = "/home/work_nfs15/asr_data/ckpt/understanding_model/epoch_13_with_asr-chat_full_data/step_32499/step_32499.pt"
+    args = SimpleNamespace(**{
+        "checkpoint": checkpoint_path,
+    })
+    configs = utils_file.load_dict_from_yaml(config_path)
+    model, configs = init_model(args, configs)
+    model = model.cuda()
+    tokenizer = init_tokenizer(configs)
+    print(model)
+    return model, tokenizer
+
 # model, tokenizer = init_model_my()
-#
-# def do_resample(input_wav_path, output_wav_path):
-#     """"""
-#     print(f'input_wav_path: {input_wav_path}, output_wav_path: {output_wav_path}')
-#     waveform, sample_rate = torchaudio.load(input_wav_path)
-#     # Check the number of channels
-#     num_channels = waveform.shape[0]
-#     # If the audio is multi-channel, average the channels
-#     if num_channels > 1:
-#         waveform = torch.mean(waveform, dim=0, keepdim=True)
-#     waveform = torchaudio.transforms.Resample(
-#         orig_freq=sample_rate, new_freq=16000)(waveform)
-#     utils_file.makedir_for_file(output_wav_path)
-#     torchaudio.save(output_wav_path, waveform, 16000)
-#
-# def true_decode_fuc(input_wav_path, input_prompt):
-#     # input_prompt = TASK_PROMPT_MAPPING.get(input_prompt, "未知任务类型")
-#     print(f"wav_path: {input_wav_path}, prompt:{input_prompt}")
-#     timestamp_ms = int(time.time() * 1000)
-#     now_file_tmp_path_resample = f'/home/xlgeng/.cache/.temp/{timestamp_ms}_resample.wav'
-#     do_resample(input_wav_path, now_file_tmp_path_resample)
-#     # tmp_vad_path = f'/home/xlgeng/.cache/.temp/{timestamp_ms}_vad.wav'
-#     # remove_silence_torchaudio_ends(now_file_tmp_path_resample, tmp_vad_path)
-#     # input_wav_path = tmp_vad_path
-#     input_wav_path = now_file_tmp_path_resample
-#     waveform, sample_rate = torchaudio.load(input_wav_path)
-#     waveform = waveform.squeeze(0)  # (channel=1, sample) -> (sample,)
-#     print(f'wavform shape: {waveform.shape}, sample_rate: {sample_rate}')
-#     window = torch.hann_window(400)
-#     stft = torch.stft(waveform,
-#                       400,
-#                       160,
-#                       window=window,
-#                       return_complex=True)
-#     magnitudes = stft[..., :-1].abs() ** 2
-#
-#     filters = torch.from_numpy(
-#         librosa.filters.mel(sr=sample_rate,
-#                             n_fft=400,
-#                             n_mels=80))
-#     mel_spec = filters @ magnitudes
-#
-#     # NOTE(xcsong): https://github.com/openai/whisper/discussions/269
-#     log_spec = torch.clamp(mel_spec, min=1e-10).log10()
-#     log_spec = torch.maximum(log_spec, log_spec.max() - 8.0)
-#     log_spec = (log_spec + 4.0) / 4.0
-#     feat = log_spec.transpose(0, 1)
-#     feat_lens = torch.tensor([feat.shape[0]], dtype=torch.int64).to(gpu_id)
-#     feat = feat.unsqueeze(0).to(gpu_id)
-#     # feat = feat.half()
-#     # feat_lens = feat_lens.half()
-#     res_text = model.generate(wavs=feat, wavs_len=feat_lens, prompt=input_prompt)[0]
-#     print("耿雪龙哈哈:", res_text)
-#     return res_text, now_file_tmp_path_resample
+print("model init success")
+def do_resample(input_wav_path, output_wav_path):
+    """"""
+    print(f'input_wav_path: {input_wav_path}, output_wav_path: {output_wav_path}')
+    waveform, sample_rate = torchaudio.load(input_wav_path)
+    # Check the number of channels
+    num_channels = waveform.shape[0]
+    # If the audio is multi-channel, average the channels
+    if num_channels > 1:
+        waveform = torch.mean(waveform, dim=0, keepdim=True)
+    waveform = torchaudio.transforms.Resample(
+        orig_freq=sample_rate, new_freq=16000)(waveform)
+    utils_file.makedir_for_file(output_wav_path)
+    torchaudio.save(output_wav_path, waveform, 16000)
+
+def true_decode_fuc(input_wav_path, input_prompt):
+    # input_prompt = TASK_PROMPT_MAPPING.get(input_prompt, "未知任务类型")
+    print(f"wav_path: {input_wav_path}, prompt:{input_prompt}")
+    timestamp_ms = int(time.time() * 1000)
+    now_file_tmp_path_resample = f'/home/xlgeng/.cache/.temp/{timestamp_ms}_resample.wav'
+    do_resample(input_wav_path, now_file_tmp_path_resample)
+    input_wav_path = now_file_tmp_path_resample
+    waveform, sample_rate = torchaudio.load(input_wav_path)
+    waveform = waveform.squeeze(0)  # (channel=1, sample) -> (sample,)
+    print(f'wavform shape: {waveform.shape}, sample_rate: {sample_rate}')
+    window = torch.hann_window(400)
+    stft = torch.stft(waveform,
+                      400,
+                      160,
+                      window=window,
+                      return_complex=True)
+    magnitudes = stft[..., :-1].abs() ** 2
+
+    filters = torch.from_numpy(
+        librosa.filters.mel(sr=sample_rate,
+                            n_fft=400,
+                            n_mels=80))
+    mel_spec = filters @ magnitudes
+
+    # NOTE(xcsong): https://github.com/openai/whisper/discussions/269
+    log_spec = torch.clamp(mel_spec, min=1e-10).log10()
+    log_spec = torch.maximum(log_spec, log_spec.max() - 8.0)
+    log_spec = (log_spec + 4.0) / 4.0
+    feat = log_spec.transpose(0, 1)
+    feat_lens = torch.tensor([feat.shape[0]], dtype=torch.int64).cuda()
+    feat = feat.unsqueeze(0).cuda()
+    # feat = feat.half()
+    # feat_lens = feat_lens.half()
+    model = None
+    res_text = model.generate(wavs=feat, wavs_len=feat_lens, prompt=input_prompt)[0]
+    print("耿雪龙哈哈:", res_text)
+    return res_text, now_file_tmp_path_resample
 @spaces.GPU
 def do_decode(input_wav_path, input_prompt):
     print(f'input_wav_path= {input_wav_path}, input_prompt= {input_prompt}')
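
The restored true_decode_fuc builds Whisper-style 80-bin log-mel features by hand (n_fft=400, hop length 160, and the clamp/log10 normalization from the openai/whisper discussion linked in the code). A self-contained sketch of that feature pipeline, assuming a mono 16 kHz waveform:

import torch
import librosa

def log_mel(waveform: torch.Tensor, sample_rate: int = 16000) -> torch.Tensor:
    # Whisper-style log-mel features, returned as (num_frames, 80).
    window = torch.hann_window(400)
    stft = torch.stft(waveform, 400, 160, window=window, return_complex=True)
    magnitudes = stft[..., :-1].abs() ** 2  # power spectrum, last frame dropped
    filters = torch.from_numpy(
        librosa.filters.mel(sr=sample_rate, n_fft=400, n_mels=80))
    mel_spec = filters @ magnitudes
    # Normalization per https://github.com/openai/whisper/discussions/269
    log_spec = torch.clamp(mel_spec, min=1e-10).log10()
    log_spec = torch.maximum(log_spec, log_spec.max() - 8.0)
    log_spec = (log_spec + 4.0) / 4.0
    return log_spec.transpose(0, 1)

feat = log_mel(torch.randn(16000))  # 1 s of noise -> (frames, 80)
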
实验室.png → lab.png RENAMED
File without changes